当前位置: 首页 > news >正文

pytorch第66页

点击查看代码
import torch
from torch import optim, nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import classification_report
from PIL import Image
import time
from matplotlib import pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#数据加载
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])# 加载CIFAR-10数据集
def load_data():train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)return train_loader, test_loader, test_dataset#定义MYVGG模型
class MYVGG(nn.Module):def __init__(self, num_classes=10):super(MYVGG, self).__init__()self.features = nn.Sequential(# Block 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 2nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 3nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 4nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),# Block 5nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),)self.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x#训练函数
model = MYVGG().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, train_loader, criterion, optimizer, epoch_num=50):model.train()train_loss = []train_acc = []for epoch in range(epoch_num):start_time = time.time()running_loss = 0.0current = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)current += (predicted == target).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100.0 * current / totaltrain_loss.append(epoch_loss)train_acc.append(epoch_acc)end_time = time.time()print(f"Epoch [{epoch+1}/{epoch_num}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%, Time: {end_time-start_time:.2f}s")return train_loss, train_acc#测试函数
def test(model, test_loader):model.eval()all_pred = []all_label = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)outputs = model(data)_, predicted = torch.max(outputs.data, 1)all_pred.extend(predicted.cpu().numpy())all_label.extend(target.cpu().numpy())all_pred = np.array(all_pred)all_label = np.array(all_label)accuracy = (all_pred == all_label).mean()accuracy = 100.0 * accuracyprint(f'测试准确率: {accuracy:.4f}%')print("分类效果评估:")target_names = [str(i) for i in range(10)]report = classification_report(all_label, all_pred, target_names=target_names)print(report)if __name__ == '__main__':print(f"24信计2班 佘婷婷 2024310143102")print(f"device:{device}")epoch_num = 20train_loader, test_loader, test_dataset = load_data()train_loss, train_acc = train(model, train_loader, criterion, optimizer, epoch_num)test(model, test_loader)#绘制结果plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epoch_num+1), train_loss)plt.title("Training Loss")plt.xlabel("Epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(range(1, epoch_num+1), train_acc)plt.title("Training Accuracy")plt.xlabel("Epoch")plt.ylabel("Accuracy (%)")plt.tight_layout()plt.show()

image

http://www.zskr.cn/news/27370.html

相关文章:

  • 有什么指标可以判断手机是否降频
  • 实用指南:Linux内核kallsyms符号压缩与解压机制
  • 从埋点到用户行为分析:ClkLog 如何帮助企业读懂用户
  • 深入解析:领码方案 | 掌控研发管理成熟度:从理论透视到AI驱动的实战进阶
  • 函数的高级
  • C#实现OPC客户端
  • 卷积神经网络的读后感
  • Calibre 8.11技术拆解:AI集成与二次开发的实战指南 - 教程
  • 5G企业应用的七大场景与商业机遇
  • 类的多态(Num020) - 实践
  • 数据类型,二元运算符,自动类型提升规则,关系运算,取余模运算
  • WPF使用MediaCapture开发相机应用(四、相机录视频)
  • 2025年10月中国婚姻家事与财富管理律师评价榜:五强评测
  • Timing Signoff 技术精要
  • 02-GPIO-铁头山羊STM32标准库新版笔记
  • 读书笔记:白话解读Oracle范围分区
  • Oracle故障处理:10G RAC srvctl注册实例正常,但是crs切不能管理实例
  • 详细介绍:资产信息收集与指纹识别:HTTPX联动工具实战指南
  • 易基因:剑桥大学团队利用微量WGBS等揭示DNMT3L在胎盘发育中的DNA甲基化调控机制:CSC(IF20.5)
  • 102302134陈蔡裔数据采集第一次作业
  • 2025吹塑机厂家权威推荐:鼎浩包装科技实力企业,专业定制高效生产方案
  • 2025年10月注册公司推荐:权威榜对比评测
  • 神经网络之激活函数Softmax - 教程
  • 2025年10月工业洗地机厂家推荐:权威评测榜与全维度对比指南
  • 2025年10月上海装修公司评测榜:五强性价比与资质全对比
  • 2025年10月洗碗机品牌评测:海信领衔榜单
  • 2025年10月代理记账公司推荐:权威榜单对比全维度评测
  • OOP-实验2
  • 2025年10月小红书代理商推荐榜:官方授权与实战案例对比评测
  • 2025年10月西安种植牙医院推荐榜:五强对比评测