当前位置: 首页 > news >正文

DDPG算法里的‘演员’和‘评论家’到底在吵什么?用Python代码逐行拆解训练过程

DDPG算法里的‘演员’和‘评论家’到底在吵什么?用Python代码逐行拆解训练过程

想象一下,你正在导演一场没有剧本的即兴戏剧。演员(Actor)需要在舞台上即兴发挥,而评论家(Critic)则在台下实时点评。这场戏的特殊之处在于——演员的动作可以精确到毫米级的角度变化,而评论家的打分标准也在不断调整。这就是DDPG(深度确定性策略梯度)算法的核心戏剧冲突。让我们用PyTorch代码作为舞台,揭开这场"表演艺术"背后的技术内幕。

1. 搭建舞台:DDPG的四大角色初始化

任何好戏都需要精心搭建舞台。在DDPG的宇宙里,我们需要先准备好四个关键神经网络:

import torch import torch.nn as nn import torch.optim as optim import numpy as np class Actor(nn.Module): def __init__(self, state_dim, action_dim, max_action): super(Actor, self).__init__() self.layer_1 = nn.Linear(state_dim, 400) self.layer_2 = nn.Linear(400, 300) self.layer_3 = nn.Linear(300, action_dim) self.max_action = max_action def forward(self, state): x = torch.relu(self.layer_1(state)) x = torch.relu(self.layer_2(x)) return self.max_action * torch.tanh(self.layer_3(x)) class Critic(nn.Module): def __init__(self, state_dim, action_dim): super(Critic, self).__init__() self.layer_1 = nn.Linear(state_dim + action_dim, 400) self.layer_2 = nn.Linear(400, 300) self.layer_3 = nn.Linear(300, 1) def forward(self, state, action): x = torch.cat([state, action], dim=1) x = torch.relu(self.layer_1(x)) x = torch.relu(self.layer_2(x)) return self.layer_3(x)

这里有两个关键设计决策值得注意:

  • Actor的输出层使用tanh:将动作限制在[-max_action, max_action]范围内
  • Critic接收状态和动作的拼接:这是Q函数的典型设计,用于评估(state, action)对的价值

四个角色的初始化就像组建剧团:

# 主演员和主评论家 actor = Actor(state_dim, action_dim, max_action) critic = Critic(state_dim, action_dim) # 备用演员和备用评论家(目标网络) target_actor = Actor(state_dim, action_dim, max_action) target_critic = Critic(state_dim, action_dim) # 初始时目标网络与主网络参数相同 target_actor.load_state_dict(actor.state_dict()) target_critic.load_state_dict(critic.state_dict())

2. 排练过程:训练循环中的动态博弈

真正的戏剧性冲突发生在训练循环中。让我们分解一个完整的训练步骤:

2.1 经验收集阶段

def select_action(state, noise): state = torch.FloatTensor(state.reshape(1, -1)) action = actor(state).data.numpy().flatten() return np.clip(action + noise, -max_action, max_action) # 在环境中执行动作并存储经验 next_state, reward, done, _ = env.step(action) replay_buffer.add(state, action, reward, next_state, done)

这里引入的探索噪声就像演员的即兴发挥——在确定性策略中加入随机性,避免表演变得刻板。常见的选择是Ornstein-Uhlenbeck噪声,它能产生时间相关的随机过程,适合物理系统的连续控制。

2.2 批评家的学习时刻

从经验池采样后,Critic开始它的"毒舌点评":

# 计算目标Q值 target_actions = target_actor(next_states) target_q_values = target_critic(next_states, target_actions) targets = rewards + (1 - dones) * gamma * target_q_values # 计算当前Q值估计 current_q_values = critic(states, actions) # Critic损失函数 critic_loss = nn.MSELoss()(current_q_values, targets.detach())

Critic的更新包含三个关键点:

  1. 使用目标网络计算target_q_values保持稳定性
  2. targets.detach()切断梯度回传,防止干扰目标网络
  3. (1 - dones)项处理回合终止时的特殊情况

2.3 演员的自我修养

Actor的更新则更有意思——它试图讨好Critic:

actor_loss = -critic(states, actor(states)).mean()

这个简单的表达式蕴含着深度策略梯度:

  • 通过Critic评估Actor当前策略的表现
  • 负号表示我们要最大化这个评估值
  • 梯度上升转化为损失函数的极小化

2.4 温和的更新:软同步机制

DDPG最精妙的设计在于目标网络的更新方式:

def soft_update(target, source, tau): for target_param, param in zip(target.parameters(), source.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) # 更新目标网络 soft_update(target_actor, actor, tau) soft_update(target_critic, critic, tau)

这种Polyak平均策略(tau通常取0.005)就像老演员缓慢吸收新演员的表演风格,避免突然的风格转变吓到观众。

3. 幕后花絮:关键技巧与调试经验

在实际制作中,有几个幕后技巧决定了演出成败:

3.1 经验回放的秘密配方

class ReplayBuffer: def __init__(self, max_size): self.buffer = [] self.max_size = max_size def add(self, state, action, reward, next_state, done): if len(self.buffer) >= self.max_size: self.buffer.pop(0) self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): indices = np.random.choice(len(self.buffer), batch_size) states, actions, rewards, next_states, dones = zip(*[self.buffer[i] for i in indices]) return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

经验回放的两个关键参数:

  • buffer大小:通常1e5到1e6,太小导致样本相关性高,太大则学习缓慢
  • batch大小:一般从128开始尝试,复杂任务可能需要更大batch

3.2 学习率的舞蹈

Actor和Critic通常需要不同的学习节奏:

actor_optimizer = optim.Adam(actor.parameters(), lr=1e-4) critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

典型配置:

  • Critic学习率是Actor的5-10倍
  • 太高的Actor学习率会导致策略震荡
  • 太低的Critic学习率则使反馈信号滞后

3.3 噪声退火策略

聪明的导演会随着排练进度减少即兴发挥:

def update_noise(noise_scale): noise_scale *= 0.9999 # 指数衰减 return max(noise_scale, 0.1) # 保持最小探索

这种退火策略平衡了:

  • 初期:高噪声促进探索
  • 后期:低噪声利于策略精修

4. 完整演出:Pendulum-v1实例解析

让我们看一个钟摆平衡的具体案例。以下是训练循环的核心代码:

for episode in range(total_episodes): state = env.reset() episode_reward = 0 noise_scale = initial_noise for step in range(max_steps): action = select_action(state, noise_scale * np.random.randn(action_dim)) next_state, reward, done, _ = env.step(action) replay_buffer.add(state, action, reward, next_state, done) state = next_state episode_reward += reward if len(replay_buffer) > batch_size: states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size) # 转换为PyTorch张量 states = torch.FloatTensor(states) actions = torch.FloatTensor(actions) rewards = torch.FloatTensor(rewards).unsqueeze(1) next_states = torch.FloatTensor(next_states) dones = torch.FloatTensor(dones).unsqueeze(1) # Critic更新 critic_optimizer.zero_grad() critic_loss = compute_critic_loss(states, actions, rewards, next_states, dones) critic_loss.backward() critic_optimizer.step() # Actor更新 actor_optimizer.zero_grad() actor_loss = compute_actor_loss(states) actor_loss.backward() actor_optimizer.step() # 软更新目标网络 soft_update(target_actor, actor, tau) soft_update(target_critic, critic, tau) noise_scale = update_noise(noise_scale) print(f"Episode {episode}, Reward: {episode_reward}")

训练过程中常见的现象记录:

训练阶段典型现象解决方案
初期 (0-1k步)奖励随机波动增加噪声规模,增大回放缓冲区
中期 (1k-10k步)偶尔出现高分但不稳定检查Critic损失是否收敛,调整学习率
后期 (>10k步)性能平台期尝试减小噪声,微调网络结构

在Pendulum-v1环境中,成功的训练通常会在约50-100个episode后开始出现稳定的摆动策略,300个episode左右能达到接近最优的性能。

http://www.zskr.cn/news/1429555.html

相关文章:

  • 1379份真实中文临床文本,含手术/药物/疾病等六类实体的字符级标注数据
  • 终极解决方案:3分钟让魔兽争霸3在现代电脑上完美运行 [特殊字符]
  • 用Python玩转赌徒问题:手把手教你实现MDP的两种经典算法(附完整代码)
  • 工程洗车台选型避坑指南:从“会喷水”到真有效,这三点经常被忽略 - 品牌优选官
  • 告别ImageNet标注!用DINO+ViT在无标签数据上实现80%+准确率的保姆级复现教程
  • #三清侠# 最近发现一个超有安全感的“新侠客”[特殊字符]
  • YOLO训练翻车?可能是你的TXT标注文件‘回炉’没做好!手把手教你TXT转回Labelme JSON
  • 大语言模型如何“认识”你:从原理到个人数字身份监控实践
  • ABB 011865-003 3/8NPT 内外丝 90° 黄铜弯头
  • 2026 中央电教馆美术教育指导教师证书详解|职业前景、报考流程、官方报名渠道推荐、证书含金量等问题一站式解答 - 教育官方推荐官
  • Gemini隐私政策不是法律文件,而是信任协议——用可验证隐私(VP)框架重构起草逻辑(含零知识证明集成示例)
  • 基于OpenCV与Mediapipe的手势识别:实现石头剪刀布人机对战
  • 3D视觉赋能新能源补能无人化:自动充电 / 换电 / 加氢场景技术落地解析
  • 牛顿迭代算法及使用条件
  • 技术风险管理实战解析与核心技术落地指南
  • 校园失物招领系统|基于Spring boot+vue的校园失物招领系统设计与实现(源码+数据库+文档)
  • Mac mini缺货涨价,无头MacBook重出江湖成AI新宠!养虾还有啥靠谱选择?
  • 外卖订餐小程序|基于java微信小程序的外卖订餐系统设计与实现(源码+数据库+文档)
  • WinDirStat:终极磁盘空间分析神器,快速释放Windows存储空间
  • AI搜索隐私生死线:从查询脱敏到结果缓存,7个被99%用户忽略的泄露入口,及3步零配置加固方案
  • AI工具安全红线清单:3类数据泄露场景、4层防护机制、1套GDPR/等保2.0合规自查表
  • 电路设计融入生活创意:从工作坊实践到智能家居应用
  • HS2-HF Patch终极指南:三分钟解锁Honey Select 2完整汉化与功能增强
  • 从零构建可复现研究叙事(Gemini+Zotero+Overleaf闭环):中科院团队实测,投稿周期压缩至11.3天
  • 保姆级教程:用CMake快速集成CSerialPort 4.3.x到你的C++项目(附完整代码)
  • Python脚本录制与回放:Appium Inspector搭配网易MuMu模拟器快速生成自动化测试代码
  • Scarab:空洞骑士模组管理的终极智能解决方案
  • 为何Synology Drive Client不能同步?
  • RPG Maker MV插件宝库:300+插件让你的游戏开发效率翻倍
  • 多功能低温性能测定仪常见故障分析与解决方法