当前位置: 首页 > news >正文

模型训练中 平均损失值和平均准确率的深入理解

aver_loss

总损失的计算

对于求平均损失来说 需要先求总损失
而求总损失 就需要求一个批次中的损失

对于一个bs来说 损失的计算是利用
loss=criterion(out,labels)计算得出
而criterion 使用的nn.crossentropy
得出来的损失值 已经是对这一个bs传入的所有样本取过平均值了
所以得出来的loss是当前bs的aver_loss

上面标亮的这段话 是求损失值的关键,也是后面两种方法的基础。
则total_loss+=loss 就计算出总损失了。

对于aver_loss 是可以有两种处理方式的。

方法一:累加“总损失”,最后除以“总样本数”

这是更精确、更标准的方法,也是 PyTorch 官方教程中常见的方式。

  1. 循环内的操作:

    Python

    running_loss += loss.item() * inputs.size(0)
    
    • loss.item():这是 PyTorch CrossEntropyLoss 默认返回的一个批次 (batch) 的平均损失

    • inputs.size(0):这是当前批次中的样本数量(也就是 batch_size)。

    • loss.item() * inputs.size(0):用“平均损失”乘以“样本数”,我们得到的实际上是这个批次的“总损失”(即损失值的加和)。

    • running_loss += ...:所以,running_loss 累加的是所有批次的总损失之和,也就是整个 epoch 见过的所有样本的损失总和

  2. 循环外的操作:

    Python

    epoch_loss = running_loss / dataset_size[phase]
    
    • 因为 running_loss所有样本的损失总和,所以我们理应除以所有样本的总数量 (dataset_size[phase]),来得到最精确的“平均到每个样本的损失”
  • 优点:这种方法可以精确地处理最后一个批次样本数不足的情况(当数据集总数不能被 batch_size 整除时),因为 inputs.size(0) 会自动适应最后一个批次的实际大小。

方法二:累加“平均损失”,最后除以“总批次数”(您提出的方式)

您的这个逻辑也是完全正确的!它代表了另一种计算思路。

  1. 要使用您的计算方法,循环内的操作应该是:

    Python

    running_loss += loss.item() 
    
    • 这里,我们累加的是每个批次的“平均损失”running_loss 最终会变成所有批次的平均损失之和
  2. 循环外的操作(如您所写):

    Python

    aver_loss = running_loss / len(dataloaders[phase])
    
    • 因为 running_loss所有批次平均损失的和,所以我们理应除以总的批次数 (len(dataloaders[phase])),来得到“每个批次的平均损失的平均值”
  • 优点:实现起来非常直观。

  • 微小缺点:当最后一个批次样本数不足时,它在计算最终平均值时,给予了这个不完整的批次的“平均损失”与其他完整批次相同的权重,理论上会引入微小的计算偏差。但在实践中,当数据集很大时,这点偏差几乎可以忽略不计。

accuracy

对于准确率来说,他是在每一个批次(bs)中
使用_,preds=torch.max(outputs,1)
torch.max的使用参考https://www.cnblogs.com/zhuzhucheng/p/19109039先求出分类的类别。
然后调用torch.sum(preds == labels.data) 求出正确预测的总数。
preds==labels.data 返回的是一个bool数组。
torch.sum则是把bool数组的true视为1 false视为0 求和 最后返回一个整数

在每一个epoch训练前定义需定义total_acc=0
在每个批次累加正确的数量到total_acc上
再一个轮次所有批次训练完后 total_acc/样本总数(也就是参考损失值计算中的dataset_size[phase]) 即为正确率。

http://www.zskr.cn/news/10825.html

相关文章:

  • 一篇了解 Git 运用方式
  • torch.max函数在分类问题中的使用 学习
  • react native 国际化 react-i18next 和 i18n,运用高级组件的形式。 - 指南
  • react性能优化
  • Gitee如何重塑中国开发者的代码托管体验
  • 模块化面向对象 2章
  • Debezium + Kafka + Flink/Doris Stream Load 实时数仓
  • 实用指南:【Makefile】Linux内核模块编译
  • Gitee DevOps平台:中国企业数字化转型的代码管理新范式
  • 幂运算与航班中转的奇妙旅行:探索算法世界的两极 - 实践
  • 论Linux安装后需要进行的配置
  • 51单片机-驱动DS1302时钟芯片模块教程 - 实践
  • 数组和链表读取、插入、删除以及查找的区别
  • 在K8S中,日志分析工具有哪些可以与K8S集群通讯?
  • 【2025最新教程】Claude Code国内使用_保姆级新手安装使用教程_最强AI编程工具
  • 如何计算sequence粒度的负载均衡损失 - 教程
  • P13885 [蓝桥杯 2023 省 Java/Python A] 反异或 01 串
  • 西电PCB设计指南第3章学习笔记
  • Vitrualbox、kali、metaspolitable2下载安装
  • llm入门环境
  • 借助Aspose.HTML控件,使用 Python 编辑 HTML
  • 汽车视频总线采集过程中,如何兼顾响应速度和可靠性?
  • 2025年十大好用网盘推荐:功能、口碑与性价比大对比
  • 使用 Ansible 批量安装 Docker
  • 二十一、DevOps:从零建设基于K8s的DevOps平台(二)
  • 新手项目经理如何选工具?2025年这5款上手快、不复杂的项目管理软件适合你
  • 用DiskGenius重新分区,检测出U盘虚标容量。
  • 2025低空经济时空信息平台
  • CF2147G
  • 全栈开发者效率工具图谱:从IDE到云服务的最优组合 - 指南