位置: IT常识 - 正文

ChatGPT强化学习大杀器——近端策略优化(PPO)

编辑:rootadmin
ChatGPT强化学习大杀器——近端策略优化(PPO) ChatGPT强化学习大杀器——近端策略优化(PPO)

推荐整理分享ChatGPT强化学习大杀器——近端策略优化(PPO),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

近端策略优化(Proximal Policy Optimization)来自 Proximal Policy Optimization Algorithms(Schulman et. al., 2017)这篇论文,是当前最先进的强化学习 (RL) 算法。这种优雅的算法可以用于各种任务,并且已经在很多项目中得到了应用,最近火爆的ChatGPT就采用了该算法。

网上讲解ChatGPT算法和训练流程的文章很多,但很少有人深入地将其中关键的近端策略优化算法讲清楚,本文我会重点讲解近端策略优化算法,并用PyTorch从头实现一遍。

文章目录强化学习算法策略优化(基于梯度)近端策略优化CLIP项价值函数项熵奖励项算法实现工具代码核心代码结论强化学习

近端策略优化作为一个先进的强化学习算法,我们首先要对强化学习有个了解。关于强化学习,介绍的文章很多,这里我不做过多介绍,但这里我们可以看一下ChatGPT是怎么解释的:

ChatGPT给出的解释比较通俗易懂,更加学术一点的说,强化学习的流程如下:

强化学习框架

上图中,每个时刻环境都会为代理反馈奖励,并监控当前状态。有了这些信息,代理就会在环境中采取行动,然后新的奖励和状态等会反馈给代理,以此形成循环。这个框架非常通用,可以应用于各种领域。

我们的目标是创建一个可以最大化获得奖励的代理。 通常这个最大化奖励是各时间折扣奖励的总和。 G=∑t=TγtrtG = \sum_{t=0}^T\gamma^tr_tG=t=0∑T​γtrt​ 这里γ\gammaγ是折扣因子,通常在 [0.95, 0.99] 范围内,rtr_trt​ 是时间 t 的奖励。

算法

那么我们如何解决强化学习问题呢? 有多种算法,可以(对于马尔可夫决策过程或 MDP)分为两类:基于模型(创建环境模型)和无模型(仅给定状态学习)。

强化学习算法分类

基于模型的算法创建环境模型并使用该模型来预测未来状态和奖励。 该模型要么是给定的(例如棋盘),要么是学习的。

无模型算法直接学习如何针对训练期间遇到的状态(策略优化或 PO)采取行动,哪些状态-行动会产生良好的回报(Q-Learning)。

我们今天讨论的近端策略优化算法属于 PO 算法家族。 因此,我们不需要环境模型来驱动学习。PO 和 Q-Learning 算法之间的主要区别在于 PO 算法可以用于具有连续动作空间的环境(即我们的动作具有真实值)并且即使该策略是随机策略(即按概率行事)也可以找到最优策略;而 Q-Learning 算法不能做这两件事。 这是PO 算法更受欢迎的另一个原因。 另一方面,Q-Learning 算法往往更简单、更直观且更易于训练。

策略优化(基于梯度)

策略优化算法可以直接学习策略。 为此,策略优化可以使用无梯度算法(例如遗传算法),也可以使用更常见的基于梯度的算法。

通过基于梯度的方法,我们指的是所有尝试估计学习策略相对于累积奖励的梯度的方法。 如果我们知道这个梯度(或它的近似值),我们可以简单地将策略的参数移向梯度的方向以最大化奖励。

策略梯度方法通过重复估计梯度g:=∇θE[∑t=∞rt]g:=\nabla_\theta\mathbb{E}[\sum_{t=0}^{\infin}r_t]g:=∇θ​E[∑t=0∞​rt​]来最大化预期总奖励。策略梯度有几种不同的相关表达式,其形式为:

g=E[∑t=∞Ψt∇θlogπθ(at∣st)](1)g=\mathbb{E}\Bigg\lbrack \sum_{t=0}^{\infin} \Psi_t \nabla_\theta log\pi_\theta(a_t \mid s_t) \Bigg\rbrack \tag{1}g=E[t=0∑∞​Ψt​∇θ​logπθ​(at​∣st​)](1)

其中 Ψt\Psi_tΨt​ 可以是如下几个:

∑t=∞rt\sum_{t=0}^\infin r_t∑t=0∞​rt​: 轨迹的总奖励∑t′=t∞rt′\sum_{t'=t}^\infin r_{t'}∑t′=t∞​rt′​: 下一个动作 ata_tat​ 的奖励∑t=∞rt−b(st)\sum_{t=0}^\infin r_t - b(s_t)∑t=0∞​rt​−b(st​): 上面公式的基线版本Qπ(st,at)Q^\pi(s_t, a_t)Qπ(st​,at​): 状态-动作价值函数Aπ(st,at)A^\pi(s_t, a_t)Aπ(st​,at​): 优势函数rt+Vπ(st+1)+Vπ(st)r_t+V^\pi(s_{t+1})+V^\pi(s_{t})rt​+Vπ(st+1​)+Vπ(st​): TD残差

后面3个公式的具体定义如下: Vπ(st):=Est+1:∞,at:∞[∑l=∞rt+l]Qπ(st,at):=Est+1:∞,at+1:∞[∑l=∞rt+l](2)V^\pi(s_t) := \mathbb{E}_{s_{t+1:\infin}, a_{t:\infin}}\Bigg\lbrack\sum_{l=0}^\infin r_{t+l} \Bigg\rbrack \\ Q^\pi(s_t, a_t) := \mathbb{E}_{s_{t+1:\infin}, a_{t+1:\infin}}\Bigg\lbrack\sum_{l=0}^\infin r_{t+l} \Bigg\rbrack \tag{2}Vπ(st​):=Est+1:∞​,at:∞​​[l=0∑∞​rt+l​]Qπ(st​,at​):=Est+1:∞​,at+1:∞​​[l=0∑∞​rt+l​](2)

Aπ(st,at):=Qπ(st,at)−Vπ(st)(3)A^\pi(s_t, a_t) := Q^\pi(s_t, a_t) - V^\pi(s_t) \tag{3}Aπ(st​,at​):=Qπ(st​,at​)−Vπ(st​)(3)

请注意,有多种方法可以估计梯度。 在这里,我们列出了 6 个不同的值:总奖励、后继动作的奖励、减去基线版本的奖励、状态-动作价值函数、优势函数(在原始 PPO 论文中使用)和时间差 (TD) 残差。我们可以选择这些值作为我们的最大化目标。 原则上,它们都提供了我们所关注的真实梯度的估计。

近端策略优化

近端策略优化简称PPO,是一种(无模型)基于策略优化梯度的算法。 该算法旨在学习一种策略,可以根据训练期间的经验最大化获得的累积奖励。

它由一个参与者(actor) πθ(.∣st)\pi\theta(. \mid st)πθ(.∣st) 和一个评估者(critic) V(st)V(st)V(st)组成。前者在时间 ttt 处输出下一个动作的概率分布,后者估计该状态的预期累积奖励(标量)。 由于 actor 和 critic 都将状态作为输入,因此可以在提取高级特征的两个网络之间共享骨干架构。

PPO 旨在使策略选择具有更高“优势”的行动,即具有比评估者预测的高得多的累积奖励。 同时,我们也不希望一次更新策略太多,这样很可能会出现优化问题。 最后,如果策略具有高熵,我们会倾向于给与额外奖励,以激励更多探索。

ChatGPT强化学习大杀器——近端策略优化(PPO)

总损失函数由三项组成:CLIP项、价值函数 (VF) 项和熵奖励项。最终目标如下: LtCLIP+VF+S(θ)=E^t[LtCLIP(θ)−c1LtVF(θ)+c2S[πθ](st)]L_t^{CLIP+VF+S}(\theta) = \hat{\mathbb{E}}_t \Big\lbrack L_t^{CLIP}(\theta) - c_1L_t^{VF}(\theta)+c_2S[\pi_\theta](s_t)\Big\rbrackLtCLIP+VF+S​(θ)=Et​[LtCLIP​(θ)−c1​LtVF​(θ)+c2​S[πθ​](st​)] 其中 c1c_1c1​ 和 c2c_2c2​ 是超参,分别衡量策略评估(critic)和探索(exploration)准确性的重要性。

CLIP项

正如我们所说,损失函数激发行为概率最大化(或最小化),从而导致行为正面优势(或负面优势) LCLIP(θ)=E^t[min(rt(θ)At^,clip(rt(θ),1−ϵ,1+ϵ)A^t)]L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\Big\lbrack min \Big\lparen r_t(\theta)\hat{A_t},clip \big\lparen r_t(\theta),1-\epsilon, 1+\epsilon\big\rparen \hat{A}_t \Big\rparen \Big\rbrackLCLIP(θ)=Et​[min(rt​(θ)At​​,clip(rt​(θ),1−ϵ,1+ϵ)At​)] 其中: rt(θ)=πθ(at∣st)πθold(at∣st)r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)}rt​(θ)=πθold​​(at​∣st​)πθ​(at​∣st​)​ 是衡量我们现在(更新后的策略)相对于之前执行该先前操作的可能性的比率。 原则上,我们不希望这个系数太大,因为太大意味着策略突然改变。这就是为什么我们采用其最小值和 [1−ϵ,1+ϵ][1-\epsilon, 1+\epsilon][1−ϵ,1+ϵ] 之间的裁剪版本,其中 ϵ\epsilonϵ 是个超参。

优势(advantage)的计算公式如下: A^t=−V(st)+rt+γrt+1+γ2rt+2+⋯+γ(T−t+1)rT−1+γT−tV(sT)\hat{A}_t = -V(s_t)+r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+\dots+\gamma^{(T-t+1)} r_{T-1} + \gamma^{T-t}V(s_T)At​=−V(st​)+rt​+γrt+1​+γ2rt+2​+⋯+γ(T−t+1)rT−1​+γT−tV(sT​) 其中:At^\hat{A_t}At​​ 是估计的优势,−V(st)-V(s_t)−V(st​) 是估计的初始状态值,γT−tV(sT)\gamma^{T-t}V(s_T)γT−tV(sT​) 是估计的终末状态值,中间部分是过程中观测到的累计奖励。

我们看到它只是衡量评估者对给定状态 sts_tst​ 的错误程度。 如果我们获得更高的累积奖励,则优势估计将为正,我们将更有可能在这种状态下采取行动。 反之亦然,如果我们期望更高的奖励但我们得到的奖励更小,则优势估计将为负,我们将降低在此步骤中采取行动的可能性。

请注意,如果我们一直走到终末状态 sTs_TsT​,我们就不再需要依赖评估者了,我们可以简单地将评估者与实际累积奖励进行比较。 在这种情况下,优势的估计就是真正的优势。

价值函数项

为了对优势有一个良好的估计,我们需要一个可以预测给定状态值的评估者。 该模型为具有简单 MSE 损失的监督学习: LtVF=MSE(rt+γrt+1+⋯+γ(T−t+1)rT−1+V(sT),V(st))=∣∣A^t∣∣2L_t^{VF} = MSE(r_t+\gamma r_{t+1}+\dots+\gamma^{(T-t+1)} r_{T-1}+V(s_T),V(s_t)) = ||\hat{A}_t||_2LtVF​=MSE(rt​+γrt+1​+⋯+γ(T−t+1)rT−1​+V(sT​),V(st​))=∣∣At​∣∣2​ 每次迭代中,我们也会更新评估者,以便随着训练的进行,它会为我们提供越来越准确的状态值。

熵奖励项

最后,我们鼓励对策略输出分布的熵进行少量奖励的探索。 标准熵为: S[πθ](st)=−∫πθ(at∣st)log(πθ(at∣st))datS[\pi_\theta](s_t) = -\int \pi_\theta(a_t \mid s_t) log(\pi_\theta(a_t \mid s_t))da_tS[πθ​](st​)=−∫πθ​(at​∣st​)log(πθ​(at​∣st​))dat​

算法实现

如果上面的讲解不够清晰,不用担心,下面将带大家从零开始一步一步实现近端策略优化算法。

工具代码

首先先导入所需的库

from argparse import ArgumentParserimport gymimport numpy as npimport wandbimport torchimport torch.nn as nnfrom torch.optim import Adamfrom torch.optim.lr_scheduler import LinearLRfrom torch.distributions.categorical import Categoricalimport pytorch_lightning as pl

PPO 的重要超参有 actor 数量、horizon、epsilon、每个优化阶段的 epoch 数量、学习率、折扣因子gamma以及权衡不同损失项的常数 c1 和 c2。 这些超参我们通过参数传入进来。

def parse_args(): """解析参数""" parser = ArgumentParser() parser.add_argument("--max_iterations", type=int, help="训练迭代次数", default=100) parser.add_argument("--n_actors", type=int, help="actor数量", default=8) parser.add_argument("--horizon", type=int, help="每个actor的时间戳数量", default=128) parser.add_argument("--epsilon", type=float, help="Epsilon", default=0.1) parser.add_argument("--n_epochs", type=int, help="每次迭代的训练轮数", default=3) parser.add_argument("--batch_size", type=int, help="Batch size", default=32 * 8) parser.add_argument("--lr", type=float, help="学习率", default=2.5 * 1e-4) parser.add_argument("--gamma", type=float, help="折扣因子gamma", default=0.99) parser.add_argument("--c1", type=float, help="损失函数价值函数的权重", default=1) parser.add_argument("--c2", type=float, help="损失函数熵奖励的权重", default=0.01) parser.add_argument("--n_test_episodes", type=int, help="Number of episodes to render", default=5) parser.add_argument("--seed", type=int, help="随机种子", default=0) return vars(parser.parse_args())

请注意,默认情况下,参数是按照论文中的描述设置的。 理想情况下,我们的代码应该尽可能在 GPU 上运行,因此我们需要设置一下torch的设备。

def get_device(): if torch.cuda.is_available(): device = torch.device("cuda") print(f"Found GPU device: {torch.cuda.get_device_name(device)}") else: device = torch.device("cpu") print("No GPU found: Running on CPU") return device

当我们执行强化学习时,通常会设一个缓冲区来存储当前模型遇到的状态、动作和奖励,用于更新我们的模型。 我们创建一个函数 run_timestamps,它将在给定的环境中运行给定的模型并获得固定数量的时间戳(如果episode 结束则重新设置环境)。 我们还使用选项 render=False 以便我们只想查看训练模型的表现。

@torch.no_grad()def run_timestamps(env, model, timestamps=128, render=False, device="cpu"): """针对给定数量的时间戳在给定环境中运行给定策略。 返回具有状态、动作和奖励的缓冲区。""" buffer = [] state = env.reset()[0] # 运行时间戳并收集状态、动作、奖励和终止 for ts in range(timestamps): model_input = torch.from_numpy(state).unsqueeze(0).to(device).float() action, action_logits, value = model(model_input) new_state, reward, terminated, truncated, info = env.step(action.item()) # (s, a, r, t)渲染到环境或存储到buffer if render: env.render() else: buffer.append([model_input, action, action_logits, value, reward, terminated or truncated]) # 更新当前状态 state = new_state # 如果episode终止或被截断,则重置环境 if terminated or truncated: state = env.reset()[0] return buffer

该函数的返回值(未渲染时)是一个缓冲区,其中包含状态、采取的行动、行动概率(logits)、评估者价值、奖励以及每个时间戳所提供策略的终止状态。 请注意,该函数使用装饰器@torch.no_grad(),因此我们不需要为与环境交互期间采取的操作存储梯度。

核心代码

有了上面的工具函数,我们就可以开发近端策略优化的核心代码了。首先新搭建main函数流程:

def main(): # 解析参数 args = parse_args() print(args) # 设置种子 pl.seed_everything(args["seed"]) # 获取设备 device = get_device() # 创建环境 env_name = "CartPole-v1" env = gym.make(env_name) # TODO 创建模型,训练模型,输出结果 model = MyPPO(env.observation_space.shape, env.action_space.n).to(device) training_loop(env, model, args) model = load_best_model() testing_loop(env, model)

上面就是整体程序的流程框架。接下来,我们只需要定义 PPO 模型、训练和测试函数。

PPO 模型的架构这里不过多阐述,我们只需要两个在环境中起作用的模型(actor和critic)。 当然,模型架构在更复杂的任务中作用至关重要,但有我们这个简单任务中,MLP 就可以完成这项工作。

因此,我们可以创建一个包含actor和critic模型的 MyPPO 类。当对某些状态运行前向方法时,我们返回actor的采样动作、每个可能动作的相对概率 (logits) 以及critic对每个状态的估计值。

class MyPPO(nn.Module): """ PPO模型的实现。 相同的代码结构即可用于actor,也可用于critic。 """ def __init__(self, in_shape, n_actions, hidden_d=100, share_backbone=False): # 父类构造函数 super(MyPPO, self).__init__() # 属性 self.in_shape = in_shape self.n_actions = n_actions self.hidden_d = hidden_d self.share_backbone = share_backbone # 共享策略主干和价值函数 in_dim = np.prod(in_shape) def to_features(): return nn.Sequential( nn.Flatten(), nn.Linear(in_dim, hidden_d), nn.ReLU(), nn.Linear(hidden_d, hidden_d), nn.ReLU() ) self.backbone = to_features() if self.share_backbone else nn.Identity() # State action function self.actor = nn.Sequential( nn.Identity() if self.share_backbone else to_features(), nn.Linear(hidden_d, hidden_d), nn.ReLU(), nn.Linear(hidden_d, n_actions), nn.Softmax(dim=-1) ) # Value function self.critic = nn.Sequential( nn.Identity() if self.share_backbone else to_features(), nn.Linear(hidden_d, hidden_d), nn.ReLU(), nn.Linear(hidden_d, 1) ) def forward(self, x): features = self.backbone(x) action = self.actor(features) value = self.critic(features) return Categorical(action).sample(), action, value

请注意,Categorical(action).sample() 创建了一个分类分布,其中包含一个动作(针对每个状态)的动作 logits 和样本。

最后,我们可以处理 training_loop 函数中的实际算法。 正如我们从论文中了解到的,该函数的实际签名应如下所示:

def training_loop(env, model, max_iterations, n_actors, horizon, gamma,epsilon, n_epochs, batch_size, lr, c1, c2, device, env_name=""): # TODO...

以下是论文中为 PPO 训练程序的伪代码:

PPO 的伪代码相对简单:我们只需通过策略模型(称为actor)的多个副本收集与环境的交互,并使用先前定义的目标来优化actor和critic网络。

由于我们需要衡量我们真正获得的累积奖励,因此需要创建一个函数,给定一个缓冲区,用累积奖励替换每个时间的奖励:

def compute_cumulative_rewards(buffer, gamma): """ 给定一个包含状态、策略操作逻辑、奖励和终止的缓冲区,计算每个时间的累积奖励并将它们代入缓冲区。 """ curr_rew = 0. # 反向遍历缓冲区 for i in range(len(buffer) - 1, -1, -1): r, t = buffer[i][-2], buffer[i][-1] if t: curr_rew = 0 else: curr_rew = r + gamma * curr_rew buffer[i][-2] = curr_rew # 在规范化之前获得平均奖励(用于日志记录和检查点) avg_rew = np.mean([buffer[i][-2] for i in range(len(buffer))]) # 规范化累积奖励 mean = np.mean([buffer[i][-2] for i in range(len(buffer))]) std = np.std([buffer[i][-2] for i in range(len(buffer))]) + 1e-6 for i in range(len(buffer)): buffer[i][-2] = (buffer[i][-2] - mean) / std return avg_rew

请注意,最后我们将累积奖励归一化处理。 这是使优化问题更容易和训练更顺利的标准技巧。

现在我们可以获得包含状态、采取的动作、动作概率和累积奖励的缓冲区,可以编写一个函数,在给定缓冲区的情况下,为我们的最终目标计算三个损失项:

def get_losses(model, batch, epsilon, annealing, device="cpu"): """给定模型、给定批次和附加参数返回三个损失项""" # 获取旧数据 n = len(batch) states = torch.cat([batch[i][0] for i in range(n)]) actions = torch.cat([batch[i][1] for i in range(n)]).view(n, 1) logits = torch.cat([batch[i][2] for i in range(n)]) values = torch.cat([batch[i][3] for i in range(n)]) cumulative_rewards = torch.tensor([batch[i][-2] for i in range(n)]).view(-1, 1).float().to(device) # 使用新模型计算预测 _, new_logits, new_values = model(states) # 状态动作函数损失(L_CLIP) advantages = cumulative_rewards - values margin = epsilon * annealing ratios = new_logits.gather(1, actions) / logits.gather(1, actions) l_clip = torch.mean( torch.min( torch.cat( (ratios * advantages, torch.clip(ratios, 1 - margin, 1 + margin) * advantages), dim=1), dim=1 ).values ) # 价值函数损失(L_VF) l_vf = torch.mean((cumulative_rewards - new_values) ** 2) # 熵奖励 entropy_bonus = torch.mean(torch.sum(-new_logits * (torch.log(new_logits + 1e-5)), dim=1)) return l_clip, l_vf, entropy_bonus

请注意,在实践中,我们使用初值为 1 并在整个训练过程中线性衰减至 0 的退火参数。 因为随着训练的进行,我们希望我们的策略变化越来越少。 另外,与 new_logits 和 new_values 不同,我们不不跟踪advantages 变量的梯度,只是张量之差。

现在我们有了环境交互和存储缓冲区、计算(真实)累积奖励并获得损失项的方法,可以着手编写最终训练代码了:

def training_loop(env, model, max_iterations, n_actors, horizon, gamma, epsilon, n_epochs, batch_size, lr, c1, c2, device, env_name=""): """使用最多n个时间戳的多个actor在给定环境中训练模型。""" # 开始运行新的权重和偏差 wandb.init(project="Papers Re-implementations", entity="peutlefaire", name=f"PPO - {env_name}", config={ "env": str(env), "number of actors": n_actors, "horizon": horizon, "gamma": gamma, "epsilon": epsilon, "epochs": n_epochs, "batch size": batch_size, "learning rate": lr, "c1": c1, "c2": c2 }) # 训练变量 max_reward = float("-inf") optimizer = Adam(model.parameters(), lr=lr, maximize=True) scheduler = LinearLR(optimizer, 1, 0, max_iterations * n_epochs) anneals = np.linspace(1, 0, max_iterations) # 训练循环 for iteration in range(max_iterations): buffer = [] annealing = anneals[iteration] # 使用当前策略收集所有actor的时间戳 for actor in range(1, n_actors + 1): buffer.extend(run_timestamps(env, model, horizon, False, device)) # 计算累积奖励并刷新缓冲区 avg_rew = compute_cumulative_rewards(buffer, gamma) np.random.shuffle(buffer) # 运行几轮优化 for epoch in range(n_epochs): for batch_idx in range(len(buffer) // batch_size): start = batch_size * batch_idx end = start + batch_size if start + batch_size < len(buffer) else -1 batch = buffer[start:end] # 归零优化器梯度 optimizer.zero_grad() # 获取损失 l_clip, l_vf, entropy_bonus = get_losses(model, batch, epsilon, annealing, device) # 计算总损失并反向传播 loss = l_clip - c1 * l_vf + c2 * entropy_bonus loss.backward() # 优化 optimizer.step() scheduler.step() # 记录输出 curr_loss = loss.item() log = f"Iteration {iteration + 1} / {max_iterations}: " \ f"Average Reward: {avg_rew:.2f}\t" \ f"Loss: {curr_loss:.3f} " \ f"(L_CLIP: {l_clip.item():.1f} | L_VF: {l_vf.item():.1f} | L_bonus: {entropy_bonus.item():.1f})" if avg_rew > max_reward: torch.save(model.state_dict(), MODEL_PATH) max_reward = avg_rew log += " --> Stored model with highest average reward" print(log) # 将信息记录到 W&B wandb.log({ "loss (total)": curr_loss, "loss (clip)": l_clip.item(), "loss (vf)": l_vf.item(), "loss (entropy bonus)": entropy_bonus.item(), "average reward": avg_rew }) # 完成 W&B 会话 wandb.finish()

最后,为了查看模型的最终效果,我们使用以下 testing_loop 函数:

def testing_loop(env, model, n_episodes, device): for _ in range(n_episodes): run_timestamps(env, model, timestamps=128, render=True, device=device)

这样,我们的主程序就会变得很简单:

def main(): # 解析参数 args = parse_args() print(args) # 设置种子 pl.seed_everything(args["seed"]) # 获取设备 device = get_device() # 创建环境 env_name = "CartPole-v1" env = gym.make(env_name) # 创建模型(actor和critic) model = MyPPO(env.observation_space.shape, env.action_space.n).to(device) # 训练 training_loop(env, model, args["max_iterations"], args["n_actors"], args["horizon"], args["gamma"], args["epsilon"], args["n_epochs"], args["batch_size"], args["lr"], args["c1"], args["c2"], device, env_name) # 加载最佳模型 model = MyPPO(env.observation_space.shape, env.action_space.n).to(device) model.load_state_dict(torch.load(MODEL_PATH, map_location=device)) # 测试 env = gym.make(env_name, render_mode="human") testing_loop(env, model, args["n_test_episodes"], device) env.close()

以上这就是全部的实现! 如果你理解上面的代码,恭喜你,你已经理解PPO算法了。

结论

近端策略优化是最先进的策略强化学习优化算法,它几乎可以在任何环境中使用。 此外,近端策略优化具有相对简单的目标函数和相对较少的需要调整的超参。

ChatGPT依赖PPO在第三步获得了超出预期的效果,大家可以在自己的强化学习任务中予以采用,可能会获得意想不到的效果。

本文链接地址:https://www.jiuchutong.com/zhishi/299249.html 转载请保留说明!

上一篇:数字图像处理总结(冈萨雷斯版)(数字图像处理总结)

下一篇:js二十五道面试题(含答案)(js面试2021)

  • 微信朋友圈营销要掌握的要点(微信朋友圈营销方案)

    微信朋友圈营销要掌握的要点(微信朋友圈营销方案)

  • 小米mix4怎么开启极速充电(小米MIX4怎么开120帧王者)

    小米mix4怎么开启极速充电(小米MIX4怎么开120帧王者)

  • 微信视频号怎么下载别人的视频(微信视频号怎么看历史观看记录)

    微信视频号怎么下载别人的视频(微信视频号怎么看历史观看记录)

  • ipad申请微信号怎么申请(ipad微信号怎么申请)

    ipad申请微信号怎么申请(ipad微信号怎么申请)

  • 苹果手表和手机微信不同步(苹果手表和手机连接不上怎么办)

    苹果手表和手机微信不同步(苹果手表和手机连接不上怎么办)

  • 打印机驱动无法使用是什么原因(打印机驱动无法删除提示正在使用中)

    打印机驱动无法使用是什么原因(打印机驱动无法删除提示正在使用中)

  • 腾讯快速会议和预定会议有什么区别(腾讯快速会议和预定会议哪个好)

    腾讯快速会议和预定会议有什么区别(腾讯快速会议和预定会议哪个好)

  • nova7支持红外吗(华为nova 7支持红外功能吗)

    nova7支持红外吗(华为nova 7支持红外功能吗)

  • 微信实名认证注销后会怎样(微信实名认证注销后重新实名认证还是原来的支付账户吗)

    微信实名认证注销后会怎样(微信实名认证注销后重新实名认证还是原来的支付账户吗)

  • 高对比度文字啥意思(高对比度有什么好处)

    高对比度文字啥意思(高对比度有什么好处)

  • 手机内存卡怎么用(手机内存卡怎么使用)

    手机内存卡怎么用(手机内存卡怎么使用)

  • 拼多多怎样同时拍两件(拼多多如何两个一起付款)

    拼多多怎样同时拍两件(拼多多如何两个一起付款)

  • 手机通知栏hd是什么意思(手机通知栏有个hd)

    手机通知栏hd是什么意思(手机通知栏有个hd)

  • 手机贴吧怎么看回复(手机贴吧怎么看等级头衔)

    手机贴吧怎么看回复(手机贴吧怎么看等级头衔)

  • 闲鱼怎样发高清图片(闲鱼怎么发清晰度高的视频)

    闲鱼怎样发高清图片(闲鱼怎么发清晰度高的视频)

  • 京东白条收款码怎么开(京东白条收款码在哪)

    京东白条收款码怎么开(京东白条收款码在哪)

  • 手机爱奇艺怎么取消广告(手机爱奇艺怎么出示二维码让别人登录)

    手机爱奇艺怎么取消广告(手机爱奇艺怎么出示二维码让别人登录)

  • 音响调节器怎么调音量(音响调节器怎么是2路输入)

    音响调节器怎么调音量(音响调节器怎么是2路输入)

  • vivo低电量模式怎么开启(vivo低电量模式有必要一直开着吗)

    vivo低电量模式怎么开启(vivo低电量模式有必要一直开着吗)

  • 苹果自带录屏在哪(苹果自带录屏在录制和平精英是怎么隐藏键位)

    苹果自带录屏在哪(苹果自带录屏在录制和平精英是怎么隐藏键位)

  • 苹果ios10.3.3怎么录屏(苹果ios10.0怎么更新)

    苹果ios10.3.3怎么录屏(苹果ios10.0怎么更新)

  • cltal01华为什么型号(华为clt-tl01)

    cltal01华为什么型号(华为clt-tl01)

  • oppor17高清通话怎么关闭(oppor17高清通话volte怎么开)

    oppor17高清通话怎么关闭(oppor17高清通话volte怎么开)

  • 淘宝怎么删除评价(淘宝怎么删除评价商品)

    淘宝怎么删除评价(淘宝怎么删除评价商品)

  • js中Array.from的用法(js array())

    js中Array.from的用法(js array())

  • 个人去税务局开劳务费怎么交税
  • 什么公司可以核算成本
  • 以前年度损益调整结转到哪里
  • 小微企业城建税及附加减半
  • 怎样根据税负率调账
  • 一千万人民币可以买多少斤黄金
  • 金蝶k3审核过账在哪里
  • 为什么纳税申报
  • 企业所得税汇算清缴补缴税款分录
  • 工会经费怎样申报
  • 换件维修的部件什么意思
  • 减免的教育费附加和地方教育费附加怎么做分录
  • 广告公司制作警示牌可以开具什么样的发票?
  • 增值税普票跨月怎么冲红
  • 什么时候需要交个人所得税
  • 营改增后印花税计税依据文件
  • 营改增后征收增值税的税目
  • 小规模纳税人差额征税
  • 加油的普票可以抵扣进项税吗
  • 外出经营地预交税金归主管税务所管吗
  • 股权转让受让方要交个人所得税吗
  • 上月材料入库会计分录
  • 股东可以随时退出吗
  • 手把手怎么样
  • 游泳耳朵进水怎么办
  • 没有发票的成本怎么算
  • 油气勘探支出包含哪些
  • 非正常损失会计利润调整
  • php中使用js
  • 公允价值变动损益借贷方向增减
  • php加密后的代码能运行吗?
  • codeignitor
  • 计算机视觉:一种现代方法
  • ConvNeXt V2学习笔记
  • 主动学习(Active Learning,AL)的理解以及代码流程讲解
  • 网络安全工具大全图片
  • 有哪些员工福利
  • 备用金管理方式
  • 收到工会经费返还属于现金流量表哪
  • 安装mysql5.1的步骤和方法
  • mysql5.7.35安装配置教程
  • 给客户的现金奖励会计处理
  • 不得开具增值税专用发票是什么意思
  • 出售固定资产不能作为企业的收入
  • 税款要在15号前扣吗
  • 应交税费案例分析题
  • 印花税减征比例
  • msdn sql server
  • mysql分页优化原理
  • mysql8.0存储过程
  • sql server 链接
  • 代扣代缴的社保为什么是其他应付款
  • 转让技术所有权是其他业务收入吗
  • 个体户是怎么交公积金的
  • 弥补亏损的会计科目有哪些
  • 房屋租赁怎么干
  • 盈余公积的提取基数
  • 企业正常经营的条件
  • 逾期的押金计入什么科目
  • 差旅费算人工费吗
  • 现金支付现金股利
  • 实际到货跟采购不一致
  • 生育津贴案件
  • 听妈妈讲那过去的事情讲课
  • sqlserver2008r2创建实例
  • MySql 5.6.35 winx64 安装详细教程
  • 如何看xp系统
  • VMware虚拟机中卸载java命令
  • linux rmdir
  • 安装win 7系统
  • -f linux命令
  • 如何检测电脑能否上网
  • linux中修改命令
  • linux网络设备有哪些
  • three.js dispose
  • socket restful
  • for语句中的++i
  • JavaScript中Textarea滚动条不能拖动的解决方法
  • 城市配套费需要交税吗
  • 如何办理清税证书
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设