当前位置: 首页 > news >正文

卷积神经网络模型搭建(pytorch版)

一个最简单的神经网络

也叫多层感知机(MLP),只有线性层(也称全连接层)和激活函数,pytorch语法写的网络结构如下:

import torch
import torch.nn as nnclass QNetwork(nn.Module):def __init__(self, state_dim, action_dim, hidden_dim=128):super(QNetwork, self).__init__()self.net = nn.Sequential(nn.Linear(state_dim, hidden_dim),  # 第1层:输入→隐藏层nn.ReLU(),                          # 激活函数(非线性)nn.Linear(hidden_dim, hidden_dim), # 第2层:隐藏层→隐藏层nn.ReLU(),nn.Linear(hidden_dim, action_dim)  # 第3层:隐藏层→输出)def forward(self, x):return self.net(x)

class QNetwork(nn.Module):
● 继承自 torch.nn.Module,这是 PyTorch 中所有神经网络模块的基类。
● 通过继承,该网络能够自动管理参数、支持 GPU 迁移、实现训练/评估模式切换等。

注意:搭配官方pytorch教程事半功倍
torch.nn.Module 是Base class for all neural network modules.

构造函数init
def __init__(self, state_dim, action_dim, hidden_dim=128):
init构造函数通常用来声明网络架构层,像上面的代码就是声明了需要的线性层和激活函数

○ 使用 nn.Sequential 将多个层按顺序组合,方便前向传播。
○ 三层全连接网络(输入层不算在内):
■ 输入 → 第一隐藏层(state_dim → hidden_dim)+ ReLU
■ 第二隐藏层(hidden_dim → hidden_dim)+ ReLU
■ 输出层(hidden_dim → action_dim):无激活函数,因为 Q 值可以是任意实数。
○ 激活函数 ReLU 引入非线性,使网络能够拟合复杂的函数关系。

forward 方法
在 PyTorch 中,当你调用网络实例(例如 q_net(state))时,会自动执行 forward 函数。它是网络的核心计算逻辑。

def forward(self, x):return self.net(x)

● 定义输入 x 如何通过网络得到输出。
● 这里直接调用 self.net(x),因为 nn.Sequential 已经包含了所有层。
● x 通常是状态张量,形状一般为 (batch_size, state_dim),输出形状为 (batch_size, action_dim),即每个动作对应的 Q 值。
image

卷积神经网络

1. 线性层(全连接层)

nn.Linear(in_features, out_features, bias=True)
● in_features:输入特征数
● out_features:输出特征数
● bias:是否使用偏置(默认 True)
layer = nn.Linear(128, 64) # 输入128维,输出64维

m = nn.Linear(20, 30)
input = torch.randn(12, 12, 20)
output = m(input)
print(output.size())

注意:nn.Linear 总是作用于输入张量的最后一个维度。
无论输入是几维张量,nn.Linear(in, out) 只关心最后一维,其他维度全部透传(batch、seq_len 等都不会受影响)。
通常作为神经网络的尾部作为分类头

self.classifier = nn.Sequential(nn.Linear(512, 256),nn.ReLU(),nn.Dropout(0.5),nn.Linear(256, 10)    # 10分类
)

2. 卷积层

卷积图解
image
具体的pytorch实现卷积操作模拟

import torch
import torch.nn.functional as Finput = torch.tensor([[1, 2, 0, 3, 1],[0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1]])
kernel = torch.tensor([[1, 2, 1],[0, 1, 0],[2, 1, 0]])input = torch.reshape(input, [1, 1, 5, 5])
kernel = torch.reshape(kernel, [1, 1, 3, 3])output = F.conv2d(input,kernel, stride=1, padding=0)
output = torch.reshape(output,[3, 3])
print(output)

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
具体链接https://docs.pytorch.org/docs/2.12/generated/torch.nn.functional.conv2d.html

怎么理解卷积核的输出shape的计算公式
image
滑了多少步?
输入宽度 W_in
卷积核宽度 kernel_size
每一步"吃掉" kernel_size 个格子
步与步之间间隔 stride
滑的步数 = (W_in - kernel_size) / stride
但这只是起始位置的数量,卷积核还要覆盖自身的宽度,所以输出宽度 = 步数 + 1(最后一个位置还能放一个卷积核):image

可能不是整除strde,所以向下取整

3. 最大池化

最大池化是下采样的一种方式
pytorch定义如下
class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
image
最大池化torch操作演示

# 最大池化import torch
import torch.nn as nninput = torch.tensor([[1, 2, 0, 3, 1],[0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1]])class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.maxPool1 = nn.MaxPool2d(kernel_size=3,ceil_mode=True)def forward(self, x):x = self.maxPool1(x)return xinput = torch.reshape(input,[1, 1, 5, 5])
mynet = MyNet()
output = mynet(input)
print(output)

注意:
(N, C, H, W) nn.MaxPool2d 要求的输入格式是 4维
↑ ↑ ↑ ↑
批次 通道 高 宽

4. 激活函数

激活函数为神经网络引入非线性,使其能够学习和表示复杂的函数映射。
常见的激活函数reLU,sigmod,softmax等
现代网络 90% 的情况用 ReLU 或 GELU,其他激活函数主要出现在特定场景。

m = nn.ReLU()
input = torch.randn(2)
output = m(input)

reLU
● 输出范围:[0, +∞)
● 优点:计算极快(一个 if 判断)、缓解梯度消失(正区间梯度恒为 1)
● 缺点:Dying ReLU — 负区间永远输出 0,梯度为 0,某些神经元可能"死掉"
● 用途:几乎所有现代网络默认首选
image

6. dropout 层

Dropout 是一种防止过拟合的正则化技术,通过在训练时随机"丢弃"部分神经元来工作。
具体做法
在训练过程中,以概率 随机地将输入张量的某些元素置零p。对于每个前向调用,零元素都是独立选择的,并且是从伯努利分布中采样的。每次前向呼叫时,每个通道都会独立地被清零。

class MyNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(512, 256)self.dropout = nn.Dropout(p=0.5)    # p=0.5,丢弃50%def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.dropout(x)                  # 训练时有效果,测试时自动关闭return x

训练 vs 测试 的区别

训练阶段: 以概率 p 随机丢弃神经元,剩余的输出放大 1/(1-p)
测试阶段: 所有神经元都参与,不丢弃,但输出保持原样
为什么要放大?
因为训练时只有 (1-p) 的神经元在工作,期望输出变小了,测试时所有神经元都工作,所以要补偿回去。如下: image

为什么能防止过拟合?
强迫网络不依赖特定神经元。
没有 Dropout 时,网络可能学到"第 5 个神经元总是负责识别猫"——这叫协同适应(co-adaptation),某个神经元坏了网络就废了。
加 Dropout 后,第 5 个神经元可能被随机丢弃,网络被迫让多个神经元共同承担识别猫的任务,结果是泛化能力更强。
通常加在隐藏层后面,不加在输出层。

简单模型搭建

image
注意激活函数没有在图片中标出,卷积、全连接等线性层之后通常紧跟一个激活函数(如 ReLU),两者结合才构成一个完整的“非线性变换层”。


import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear, Module, ReLUclass MaoNet(Module):def __init__(self):super(MaoNet, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),ReLU(),  # 激活函数MaxPool2d(2),Conv2d(32, 32, 5, padding=2),ReLU(),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),ReLU(),MaxPool2d(2),Flatten(),Linear(1024, 64),ReLU(),Linear(64, 10))def forward(self, x):x = self.model1(x)return x
maonet = MaoNet()
input = torch.ones([64, 3, 32, 32])
output = maonet(input)
print(output.shape)
http://www.zskr.cn/news/1495700.html

相关文章:

  • TPM2-TSS快速入门:5步搭建可信计算开发环境
  • Audacity音频编辑神器:3大核心功能解决你的音频处理难题
  • 从一次信息泄露事件复盘:你的邮箱密码还在这些高危网站用吗?
  • Runtime昇腾运行时引擎深度解析:算子调度与执行管理的核心原理
  • 纪念币真假鉴别技巧!普通人在家就能查,杜绝高仿假货 - 深鉴新闻
  • CodeIsland与竞争对手对比:为什么它是AI编程助手监控的终极选择 [特殊字符]
  • 喜马拉雅音频离线神器:跨平台下载工具全面解析
  • 如何在Windows上安装安卓应用:APK安装器的完整指南
  • 卡梅德生物技术快报|纯化重组蛋白实操详解
  • Scala Pickling 源码解析:编译时生成与运行时反射的实现原理
  • 智能对话革命:ChatALL助你一站式管理所有AI助手
  • Finance-Python部署指南:生产环境配置与性能调优
  • 从SRResNet到SRGAN:一个ResNet块如何‘进化’成GAN,彻底改变图像超分的游戏规则
  • 雷达原理与系统基础教程
  • Win32 - 进程间通信(IPC)1
  • 上海寄快递哪家便宜?我对比了5家告诉你 - 快递物流资讯
  • 基于趋化因子CCL21与细胞因子IL-7协同作用的CAR-T细胞策略:从机制探索到实体瘤治疗应用
  • Week 3 -- Day 1:LangGraph 入门
  • 2025 Alpha活性助焊膏官方授权榜:爱法核心工艺领衔,五家高口碑品牌深度解析 - 品牌发掘
  • 完整指南:5步掌握Switch宝可梦ROM编辑器pkNX的核心技巧
  • Node.js 事件循环与异步调度:从单线程到高并发的底层机制,理解 libuv 的调度哲学
  • 从手动重复到智能自动化:Templater如何彻底改变你的Obsidian笔记体验
  • 如何设计一个幂等接口
  • 卡梅德生物科普:MAPT(微管相关蛋白Tau)
  • 神经渲染+GIS:当数字地球拥有“大脑”,未来已来!
  • 专业级磁盘健康监控实战指南:smartmontools 7.5深度解析
  • 5分钟搭建PUBG雷达系统:免费开源的游戏地图可视化工具终极指南
  • 5大技术突破:Midscene.js如何重新定义跨平台AI自动化测试
  • 2026年全自动绕线机厂家TOP榜:专用收线绕线机/精密绕线机/多功能绕线机源头厂家与技术创新推荐 - 企业推荐官【官方】
  • 成都2026瓷砖空鼓翘边拱起原因及解决办法 免砸砖快速修复 - 苏易房屋修缮