当前位置: 首页 > news >正文

吴恩达深度学习笔记第三周:手把手推导单隐层神经网络的前向与反向传播

吴恩达深度学习笔记第三周:手把手推导单隐层神经网络的前向与反向传播

在Coursera的深度学习专项课程中,吴恩达教授将神经网络的基础知识拆解为易于消化的模块。但对于真正想掌握其数学本质的学习者来说,仅观看视频远远不够。本文将以单隐层神经网络为例,带你用纸笔一步步推导矩阵运算与梯度计算的全过程,这是理解更复杂架构的必经之路。

1. 单隐层神经网络的结构拆解

一个标准的单隐层神经网络包含三层结构:输入层(第0层)、隐藏层(第1层)和输出层(第2层)。假设我们有:

  • 输入特征维度:3(即x₁, x₂, x₃)
  • 隐藏层神经元数量:4
  • 输出层神经元数量:1(二分类问题)

各层参数矩阵的维度必须严格对应:

  • W⁽¹⁾:隐藏层权重矩阵,维度为(4,3)
    行数=当前层神经元数,列数=前一层特征数
  • b⁽¹⁾:隐藏层偏置向量,维度为(4,1)
  • W⁽²⁾:输出层权重矩阵,维度为(1,4)
  • b⁽²⁾:输出层偏置标量,维度为(1,1)

初始化技巧:W应采用np.random.randn()*0.01进行微小随机初始化,b可初始化为零向量

2. 前向传播的矩阵运算推导

前向传播需要依次计算隐藏层和输出层的线性组合(z)与激活输出(a)。我们以单个样本为例:

2.1 隐藏层计算

z⁽¹⁾ = W⁽¹⁾·x + b⁽¹⁾ # (4,3)×(3,1)+(4,1) → (4,1) a⁽¹⁾ = g(z⁽¹⁾) # g为激活函数如ReLU

关键步骤验证:

  1. W⁽¹⁾·x执行矩阵乘法,每个隐藏神经元接收所有输入特征的加权和
  2. 偏置b⁽¹⁾逐元素添加到结果中
  3. 激活函数g(·)按元素作用,引入非线性

2.2 输出层计算

z⁽²⁾ = W⁽²⁾·a⁽¹⁾ + b⁽²⁾ # (1,4)×(4,1)+(1,1) → (1,1) a⁽²⁾ = σ(z⁽²⁾) # 二分类常用sigmoid

此时a⁽²⁾即为预测输出ŷ。对于m个样本的批量处理,只需将输入x扩展为矩阵X(每列一个样本),所有中间结果维度末尾增加m。

3. 反向传播的梯度计算原理

反向传播通过链式法则计算损失函数对各参数的偏导。假设使用交叉熵损失J,关键导数如下:

3.1 输出层梯度

dz⁽²⁾ = a⁽²⁾ - y # (1,1) dW⁽²⁾ = dz⁽²⁾·(a⁽¹⁾)ᵀ # (1,1)×(1,4) → (1,4) db⁽²⁾ = dz⁽²⁾ # (1,1)

3.2 隐藏层梯度

dz⁽¹⁾ = (W⁽²⁾)ᵀ·dz⁽²⁾ * g'(z⁽¹⁾) # (4,1)×(1,1) → (4,1) dW⁽¹⁾ = dz⁽¹⁾·xᵀ # (4,1)×(1,3) → (4,3) db⁽¹⁾ = dz⁽¹⁾ # (4,1)

其中g'(z⁽¹⁾)是激活函数的导数:

  • ReLU导数为1(z>0)或0(z≤0)
  • Sigmoid导数为a(1-a)

4. 激活函数的选择与实现

不同层的激活函数选择直接影响模型性能:

激活函数适用场景优点缺点
Sigmoid输出层(二分类)输出范围(0,1)易梯度消失
Tanh隐藏层零中心化计算量较大
ReLU隐藏层(默认)计算简单/缓解梯度消失负数区失效

实际编码时,激活函数及其导数可并行计算:

def relu(z): return np.maximum(0,z) def relu_derivative(z): return (z > 0).astype(float)

5. 向量化实现技巧

批量处理m个样本时,矩阵运算需特别注意维度对齐。以隐藏层为例:

Z⁽¹⁾ = W⁽¹⁾·X + b⁽¹⁾ # (4,3)×(3,m)+(4,1) → (4,m)

这里通过广播机制自动扩展b⁽¹⁾。反向传播时:

dW⁽²⁾ = (1/m) * dz⁽²⁾·(A⁽¹⁾)ᵀ # (1,m)×(m,4) → (1,4)

这种实现比循环快100倍以上。一个完整的训练迭代包含:

  1. 前向传播计算预测值
  2. 计算损失J
  3. 反向传播求梯度
  4. 梯度下降更新参数:
    W⁽¹⁾ -= α·dW⁽¹⁾ b⁽¹⁾ -= α·db⁽¹⁾

推导过程中最容易出错的是矩阵维度匹配。建议在编写代码前先手写验证维度变化,例如:

  • W⁽¹⁾·x的(4,3)×(3,1)确实得到(4,1)
  • 反向传播时(W⁽²⁾)ᵀ·dz⁽²⁾的(4,1)×(1,1)通过广播变为(4,1)
http://www.zskr.cn/news/1469610.html

相关文章:

  • AI工具如何重构排序逻辑:7个被90%团队忽略的智能排序性能拐点
  • 不用下载直接改!主流网盘在线编辑功能深度实测 - 品牌测评鉴赏家
  • 家用台式洗碗机实力品牌推荐榜单:GORGENOX歌嘉诺凭精工高性价比领跑,台式洗碗机、免安装洗碗机、超窄洗碗机、嵌入式美妆冰箱、台下嵌入式冰箱高口碑全解析 - 变量人生001
  • 实在Agent有没有针对开发者的个人终身免费版?2026开发者政策与企业级AI智能体演进深度评测
  • TIA Portal避坑指南:Get_Alarm指令读取ProDiag报警的5个常见错误与调试技巧
  • opencv识别抖音的评论区其实很简单
  • AcFunDown:你的A站视频离线收藏神器
  • 2026年委托公证最新办理方法有哪些?网上办公证流程 - GrowthUME
  • 北京京顺斋,天津全域上门收宝,让每一件藏品都有归处 - 深鉴新闻
  • AKM系列有铁芯直线电机:大推力与高刚性的精密驱动之选
  • AI辅助开发网络加密应用:让快马智能生成WebSocket安全通信代码
  • 3分钟找回Navicat密码:你的数据库连接救星工具
  • Cursor Free VIP技术解析:机器标识重置与账户管理机制深度剖析
  • 工程师自学三大误区:从目标分解到MVP思维,高效掌握嵌入式开发
  • 【AI伦理治理实战框架】:从0到1搭建企业级AI使用审计体系——含GDPR/网信办双标对照矩阵
  • 如何用uBlock Origin在5分钟内打造无广告、保护隐私的浏览体验
  • 2026年针织大圆机/纺织设备/针织布源头厂家推荐榜:高端机械与精湛工艺的全景解析及选购指南 - 品牌企业推荐师(官方)
  • 读水识鱼——钓鱼高手的必修课 - 教育信息速递
  • Linux 内核参数企业级优化(生产稳定调优)
  • 5个技巧让Windows Terminal成为你的终极命令行工作台
  • 从IMU预积分到VIO:手把手推导ESKF,并聊聊它为什么比EKF更适合SLAM
  • LSTM实战:基于快马平台生成智能古诗创作应用完整项目
  • 实测Win11Debloat:系统化优化Windows体验的完整解决方案
  • Windows平台APK安装三步法:零基础实现安卓应用无缝运行
  • 别急着换IDE!PIL的DecompressionBombWarning,用这3招在PyCharm里也能搞定大图拼接
  • MATLAB版CAN报文实时解析与工程值可视化工具
  • Flutter热更新原理与实现方法
  • 从零开始:如何用ReadCat打造你的专属数字书房
  • DVWA-Command Injection
  • 安装 Python 3.10+