当前位置: 首页 > news >正文

NiN模型

NiN模型

import torch
from torch import nn
from d2l import torch as d2l
def NiN_block(in_chanels,out_chanels,kernel_size,padding,stride):#NiN块return nn.Sequential(nn.Conv2d(in_chanels,out_chanels,kernel_size,padding=padding,stride=stride),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU())#NiN网络
net=nn.Sequential(NiN_block(1,96,11,stride=4,padding=0),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(96,256,kernel_size=5,padding=2,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(256,384,kernel_size=3,padding=1,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),nn.Dropout(p=0.5),NiN_block(384,10,3,1,1),#输出通道最终为10,因为等会要用于数字0~9分类nn.AdaptiveAvgPool2d((1,1)),nn.Flatten()
)
X=torch.rand((1,1,224,224))
for layer in net:X=layer(X)print(layer.__class__.__name__,X.shape)

 解释分析:

nn.AdaptiveAvgPool2d((1, 1)) 是 PyTorch 中的自适应平均池化层,它的作用是将输入的任意尺寸的特征图,通过平均池化操作,固定输出为 (1, 1) 大小的特征图(即高和宽都为 1)。

具体解释:

  1. 自适应(Adaptive)
     
    与普通的 nn.AvgPool2d 不同,它不需要手动指定池化核的大小(kernel_size)和步长(stride),而是直接指定输出特征图的尺寸。PyTorch 会自动计算所需的池化核大小和步长,以确保输出符合指定尺寸。
  2. 参数 (1, 1)
     
    表示输出特征图的高和宽都为 1。例如:
    • 如果输入是形状为 (N, C, H, W) 的特征图(N 是批量大小,C 是通道数,H 是高,W 是宽),
    • 经过 nn.AdaptiveAvgPool2d((1, 1)) 后,输出形状会变为 (N, C, 1, 1)
  3. 在你的代码中的作用
     
    在 NiN 网络中,最后一个 NiN_block 的输出通道数是 10(对应 10 个类别),假设此时特征图形状为 (N, 10, H, W)(例如经过前面的层后,H 和 W 可能是 7 左右)。
     
    通过 nn.AdaptiveAvgPool2d((1, 1)) 后,特征图会被压缩为 (N, 10, 1, 1),再经过 nn.Flatten() 展平为 (N, 10),正好对应 10 个类别的输出,可直接用于分类任务(如计算交叉熵损失)。
 
简单说,这个层的核心作用是 **“压缩空间维度,保留通道信息”**,方便后续将特征图转换为分类所需的向量形式。

二.训练NiN网络

import torch
from torch import nn
from d2l import torch as d2l
def NiN_block(in_chanels,out_chanels,kernel_size,padding,stride):return nn.Sequential(nn.Conv2d(in_chanels,out_chanels,kernel_size,padding=padding,stride=stride),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_chanels,out_chanels,kernel_size=1),nn.ReLU())
net=nn.Sequential(NiN_block(1,96,11,stride=4,padding=0),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(96,256,kernel_size=5,padding=2,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),NiN_block(256,384,kernel_size=3,padding=1,stride=1),nn.MaxPool2d(kernel_size=3,stride=2),nn.Dropout(p=0.5),NiN_block(384,10,3,1,1),#输出通道最终为10,因为等会要用于数字0~9分类nn.AdaptiveAvgPool2d((1,1)),nn.Flatten()
)
batch_size=128#批量数
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size,resize=224)
lr=0.05#学习率
nums_epochs=10#学习10代
d2l.train_ch6(net,train_iter,test_iter,nums_epochs,lr,d2l.try_gpu())#这个函数封装在d2l(本书的一个包)
 
http://www.zskr.cn/news/22600.html

相关文章:

  • 可能是 ICPC2025 西安站游记
  • 知识学报:DP(1)
  • Git SSH 推送完整流程总结
  • 运筹学奖学金项目促进科研多元化发展
  • 非托管内存怎么计算?
  • ubuntu配置镜像源和配置containerd安装源
  • 【题解】CF2086C Disappearing Permutation
  • 5-互评-OO之接口-DAO模式代码阅读及应用
  • PWN手的成长之路-18-ciscn_2019_ne_5-rettext
  • 3.springboot-容器机制-@注解
  • 日志分析-windows日志分析base
  • 课后作业3
  • KMP和Manacher
  • 索引有什么作用?
  • LinuxC++——etcd-cpp-api精简源代码函数参数查询参考 - 教程
  • mongoDB体验
  • TELUS如何通过Google技术栈实现业务增长与生产力跃升
  • 你的程序为何卡顿?从LINUX I/O三大模式寻找答案
  • 日总结 13
  • 题解:P8019 [ONTAK2015] OR-XOR
  • DP 思维好题(转载)
  • python sse的是什么?
  • 万字长文详述单据引擎原理、流程、单据管理 - 智慧园区
  • 【比赛记录】2025NOIP 冲刺模拟赛合集I
  • 12 继承--instanceof和类型转换
  • CSDN Markdown 编辑器快捷键大全 - 实践
  • Java了解
  • NVIDIA Jetson AGX Xavier刷机教程
  • 洛谷p1462-通往奥格瑞码道路
  • AI安全新威胁:提示注入与模型中毒攻击深度解析