当前位置: 首页 > news >正文

用Python玩转赌徒问题:手把手教你实现MDP的两种经典算法(附完整代码)

用Python玩转赌徒问题:手把手教你实现MDP的两种经典算法(附完整代码)

马尔科夫决策过程(MDP)是强化学习的基础框架之一,而赌徒问题则是理解MDP的绝佳案例。本文将带你从零开始,用Python实现策略迭代和值迭代这两种经典算法,并通过可视化分析不同参数下的策略变化。无论你是想巩固理论知识,还是希望获得可复用的代码模板,这篇文章都能满足你的需求。

1. 环境准备与问题建模

在开始编码前,我们需要明确赌徒问题的数学模型。假设一个赌徒初始有s美元(1≤s≤99),每次可以选择下注1到min(s,100-s)美元。硬币正面朝上的概率为ph,获胜则获得下注金额,失败则失去下注金额。游戏在达到100美元或破产时结束。

首先安装必要的库:

pip install numpy matplotlib seaborn

定义问题参数:

GOAL = 100 # 目标金额 STATES = np.arange(GOAL + 1) # 所有可能状态(0到100) ph = 0.4 # 硬币正面概率 gamma = 1 # 折扣因子

状态值函数初始化时,只有达到目标状态(100)才有奖励1:

state_values = np.zeros(GOAL + 1) state_values[GOAL] = 1.0

2. 策略迭代算法实现

策略迭代分为两个交替进行的阶段:策略评估和策略改进。我们先来看完整的类实现:

class PolicyIteration: def __init__(self, goal=100, proba_h=0.4, theta=1e-9, gamma=1): self.ph = proba_h self.gamma = gamma self.goal = goal self.theta = theta self.states = np.arange(goal + 1) self.state_values = np.zeros(goal + 1) self.state_values[goal] = 1.0 self.policy = np.zeros(goal + 1) # 初始策略全0 self.sweeps_history = [] # 记录每次迭代的值函数 def policy_evaluation(self): while True: old_values = self.state_values.copy() self.sweeps_history.append(old_values) for s in self.states[1:self.goal]: actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 使用当前策略选择动作 current_a = int(self.policy[s]) if current_a == 0 and s < self.goal: # 初始策略为0,需要处理 current_a = actions[0] self.policy[s] = current_a self.state_values[s] = action_returns[actions.tolist().index(current_a)] delta = np.abs(self.state_values - old_values).max() if delta <= self.theta: break def policy_improvement(self): policy_stable = True for s in self.states[1:self.goal]: old_a = self.policy[s] actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 选择回报最大的动作 max_a = actions[np.argmax(np.round(action_returns, 5))] self.policy[s] = max_a if old_a != max_a: policy_stable = False return policy_stable def solve(self): while True: self.policy_evaluation() if self.policy_improvement(): break

关键点说明:

  • policy_evaluation通过迭代更新状态值函数,直到变化小于阈值theta
  • policy_improvement根据当前值函数选择最优动作
  • solve方法交替执行上述两个步骤直到策略稳定

3. 值迭代算法实现

值迭代将策略评估和策略改进合并为一个步骤,直接更新最优值函数:

class ValueIteration: def __init__(self, goal=100, proba_h=0.4, theta=1e-9, gamma=1): self.ph = proba_h self.gamma = gamma self.goal = goal self.theta = theta self.states = np.arange(goal + 1) self.state_values = np.zeros(goal + 1) self.state_values[goal] = 1.0 self.policy = np.zeros(goal + 1) self.sweeps_history = [] def value_iteration(self): while True: old_values = self.state_values.copy() self.sweeps_history.append(old_values) for s in self.states[1:self.goal]: actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 直接取最大值作为新状态值 self.state_values[s] = np.max(action_returns) delta = np.abs(self.state_values - old_values).max() if delta <= self.theta: break def derive_policy(self): for s in self.states[1:self.goal]: actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 选择最优动作 self.policy[s] = actions[np.argmax(np.round(action_returns, 5))] def solve(self): self.value_iteration() self.derive_policy()

与策略迭代的主要区别:

  • 每次直接更新为最优值(取max),而不是当前策略下的期望值
  • 值收敛后才一次性推导出策略

4. 结果分析与可视化

实现算法后,我们比较ph=0.4和ph=0.55两种情况下的策略差异:

def plot_results(ph, title): # 策略迭代 pi = PolicyIteration(proba_h=ph) pi.solve() # 值迭代 vi = ValueIteration(proba_h=ph) vi.solve() plt.figure(figsize=(12, 8)) # 绘制策略 plt.subplot(2, 2, 1) plt.step(pi.states, pi.policy, where='post') plt.title(f'Policy Iteration (ph={ph})') plt.xlabel('Capital') plt.ylabel('Optimal stake') plt.subplot(2, 2, 2) plt.step(vi.states, vi.policy, where='post') plt.title(f'Value Iteration (ph={ph})') plt.xlabel('Capital') plt.ylabel('Optimal stake') # 绘制值函数 plt.subplot(2, 2, 3) plt.plot(pi.states, pi.state_values) plt.title('State Values (PI)') plt.xlabel('Capital') plt.ylabel('Value estimate') plt.subplot(2, 2, 4) plt.plot(vi.states, vi.state_values) plt.title('State Values (VI)') plt.xlabel('Capital') plt.ylabel('Value estimate') plt.tight_layout() plt.show() plot_results(0.4, "ph=0.4") plot_results(0.55, "ph=0.55")

关键发现:

  1. 当ph=0.4(劣势赌局)时,两种算法都建议保守策略,只在特定资本时下注较大金额
  2. 当ph=0.55(优势赌局)时,最优策略变得更激进,建议更大胆的下注
  3. 值迭代收敛更快,但策略迭代的策略变化过程更平滑

5. 算法对比与工程实践

在实际应用中,两种算法各有优劣:

特性策略迭代值迭代
收敛速度较慢较快
每次迭代计算量较大较小
中间结果可用性每次迭代都有完整策略只有最终策略
实现复杂度较高较低
适合场景需要中间策略/策略变化平缓只需最终结果/快速原型开发

工程优化建议:

  1. 向量化计算:将内部循环改为矩阵运算
# 替代原来的for循环 returns = ph * values[s + actions] + (1 - ph) * values[s - actions]
  1. 并行化:使用多进程处理状态更新
  2. 早期终止:检测策略是否早停滞
  3. 日志记录:保存每次迭代变化用于调试

常见问题解决:

  • 振荡问题:适当减小学习率或增加theta值
  • 收敛慢:检查奖励设置和折扣因子
  • 内存不足:使用稀疏矩阵表示大状态空间
# 示例:带收敛诊断的改进版值迭代 def value_iteration_enhanced(max_iter=1000): for i in range(max_iter): old_values = values.copy() for s in states[1:GOAL]: # ... 更新逻辑 ... delta = np.abs(values - old_values).max() if delta < theta: print(f"Converged at iteration {i}") break elif i % 10 == 0: print(f"Iter {i}, delta={delta:.4f}")

通过这个完整的实现案例,我们不仅掌握了MDP两种基本算法的编程技巧,还深入理解了它们在策略形成上的差异。建议读者尝试修改参数(如ph、gamma)或奖励函数,观察策略如何随之变化,这是巩固MDP概念的最佳方式。

http://www.zskr.cn/news/1429545.html

相关文章:

  • 工程洗车台选型避坑指南:从“会喷水”到真有效,这三点经常被忽略 - 品牌优选官
  • 告别ImageNet标注!用DINO+ViT在无标签数据上实现80%+准确率的保姆级复现教程
  • #三清侠# 最近发现一个超有安全感的“新侠客”[特殊字符]
  • YOLO训练翻车?可能是你的TXT标注文件‘回炉’没做好!手把手教你TXT转回Labelme JSON
  • 大语言模型如何“认识”你:从原理到个人数字身份监控实践
  • ABB 011865-003 3/8NPT 内外丝 90° 黄铜弯头
  • 2026 中央电教馆美术教育指导教师证书详解|职业前景、报考流程、官方报名渠道推荐、证书含金量等问题一站式解答 - 教育官方推荐官
  • Gemini隐私政策不是法律文件,而是信任协议——用可验证隐私(VP)框架重构起草逻辑(含零知识证明集成示例)
  • 基于OpenCV与Mediapipe的手势识别:实现石头剪刀布人机对战
  • 3D视觉赋能新能源补能无人化:自动充电 / 换电 / 加氢场景技术落地解析
  • 牛顿迭代算法及使用条件
  • 技术风险管理实战解析与核心技术落地指南
  • 校园失物招领系统|基于Spring boot+vue的校园失物招领系统设计与实现(源码+数据库+文档)
  • Mac mini缺货涨价,无头MacBook重出江湖成AI新宠!养虾还有啥靠谱选择?
  • 外卖订餐小程序|基于java微信小程序的外卖订餐系统设计与实现(源码+数据库+文档)
  • WinDirStat:终极磁盘空间分析神器,快速释放Windows存储空间
  • AI搜索隐私生死线:从查询脱敏到结果缓存,7个被99%用户忽略的泄露入口,及3步零配置加固方案
  • AI工具安全红线清单:3类数据泄露场景、4层防护机制、1套GDPR/等保2.0合规自查表
  • 电路设计融入生活创意:从工作坊实践到智能家居应用
  • HS2-HF Patch终极指南:三分钟解锁Honey Select 2完整汉化与功能增强
  • 从零构建可复现研究叙事(Gemini+Zotero+Overleaf闭环):中科院团队实测,投稿周期压缩至11.3天
  • 保姆级教程:用CMake快速集成CSerialPort 4.3.x到你的C++项目(附完整代码)
  • Python脚本录制与回放:Appium Inspector搭配网易MuMu模拟器快速生成自动化测试代码
  • Scarab:空洞骑士模组管理的终极智能解决方案
  • 为何Synology Drive Client不能同步?
  • RPG Maker MV插件宝库:300+插件让你的游戏开发效率翻倍
  • 多功能低温性能测定仪常见故障分析与解决方法
  • 胖头鱼的技术专栏-430 国产数据库的下半场:固疆也须扩土(20260529)
  • Unity 2021+ 开发者的福音:用这个Editor脚本告别Ctrl+S后的漫长编译等待
  • Lovable区块链平台治理模块逆向工程:Governance Token经济学模型与投票延迟根因分析(仅限首批内测伙伴解密版)