当前位置: 首页 > news >正文

第七章 手写数字识别V4

# 优化:
# 增加父类Module,输出每层信息
# 增加ReLU类,Tanh类
# 增加Dropout类,随机失活,防止过拟合,提高泛化能力
# 增加Parameter类,保存权重和梯度# 导入必要的库
import numpy as np
import os
import struct# 定义导入函数
def load_images(path):with open(path, "rb") as f:data = f.read()magic_number, num_items, rows, cols = struct.unpack(">iiii", data[:16])return np.asanyarray(bytearray(data[16:]), dtype=np.uint8).reshape(num_items, 28, 28)def load_labels(file):with open(file, "rb") as f:data = f.read()return np.asanyarray(bytearray(data[8:]), dtype=np.int32)# 定义sigmoid函数
def sigmoid(x):result = np.zeros_like(x)positive_mask = x >= 0result[positive_mask] = 1 / (1 + np.exp(-x[positive_mask]))negative_mask = x < 0exp_x = np.exp(x[negative_mask])result[negative_mask] = exp_x / (1 + exp_x)return result# 定义softmax函数
def softmax(x):max_x = np.max(x, axis=-1, keepdims=True)x = x - max_xex = np.exp(x)sum_ex = np.sum(ex, axis=1, keepdims=True)result = ex / sum_exresult = np.clip(result, 1e-10, 1e10)return result# 定义独热编码函数
def make_onehot(labels, class_num):result = np.zeros((labels.shape[0], class_num))for idx, cls in enumerate(labels):result[idx, cls] = 1return result# 定义dataset类
class Dataset:def __init__(self, all_images, all_labels):self.all_images = all_imagesself.all_labels = all_labelsdef __getitem__(self, index):image = self.all_images[index]label = self.all_labels[index]return image, labeldef __len__(self):return len(self.all_images)# 定义dataloader类
class DataLoader:def __init__(self, dataset, batch_size, shuffle=True):self.dataset = datasetself.batch_size = batch_sizeself.shuffle = shuffleself.idx = np.arange(len(self.dataset))def __iter__(self):# 如果需要打乱,则在每个 epoch 开始时重新排列索引if self.shuffle:np.random.shuffle(self.idx)self.cursor = 0return selfdef __next__(self):if self.cursor >= len(self.dataset):raise StopIteration# 使用索引来获取数据batch_idx = self.idx[self.cursor : min(self.cursor + self.batch_size, len(self.dataset))]batch_images = self.dataset.all_images[batch_idx]batch_labels = self.dataset.all_labels[batch_idx]self.cursor += self.batch_sizereturn batch_images, batch_labels# 定义Module类
class Module:  # 父类Moduledef __init__(self):self.info = "Module:\n"def __repr__(self):return self.info# 定义Parameters类
class Parameters:def __init__(self, weight):self.weight = weightself.grad = np.zeros_like(weight)# 定义linear类
class Linear(Module):def __init__(self, in_features, out_features):super().__init__()  # 调用父类初始化方法self.info += f"**    Linear({in_features}, {out_features})"  # 打印信息self.W = Parameters(np.random.normal(0, 1, size=(in_features, out_features)))self.B = Parameters(np.random.normal(0, 1, size=(1, out_features)))def forward(self, x):self.x = xreturn np.dot(x, self.W.weight) + self.B.weightdef backward(self, G):self.W.grad = np.dot(self.x.T, G)self.B.grad = np.mean(G, axis=0, keepdims=True)self.W.weight -= lr * self.W.gradself.B.weight -= lr * self.B.gradreturn np.dot(G, self.W.weight.T)# 定义Sigmoid类
class Sigmoid(Module):def __init__(self):super().__init__()self.info += "**    Sigmoid()"  # 打印信息return self.infodef forward(self, x):self.result = sigmoid(x)return self.resultdef backward(self, G):return G * self.result * (1 - self.result)# 定义Tanh类
class Tanh(Module):def __init__(self):super().__init__()self.info += "**    Tanh()"  # 打印信息def forward(self, x):self.result = 2 * sigmoid(2 * x) - 1return self.resultdef backward(self, G):return G * (1 - self.result**2)# 定义Softmax类
class Softmax(Module):def __init__(self):super().__init__()self.info += "**    Softmax()"  # 打印信息def forward(self, x):self.p = softmax(x)return self.pdef backward(self, G):G = (self.p - G) / len(G)return G# 定义ReLU类
class ReLU(Module):def __init__(self):super().__init__()self.info += "**    ReLU()"  # 打印信息def forward(self, x):self.x = xreturn np.maximum(0, x)def backward(self, G):grad = G.copy()grad[self.x <= 0] = 0return grad# 定义Dropout类
class Dropout(Module):def __init__(self, p=0.3):super().__init__()self.info += f"**    Dropout(p={p})"  # 打印信息self.p = pdef forward(self, x):r = np.random.rand(*x.shape)  # 矩阵r与x的shape相同,值在0-1之间随机生成self.nagtive = r < self.px[self.nagtive] = 0return xdef backward(self, G):G[self.nagtive] = 0return G# 定义ModelList类
class ModelList:def __init__(self, layers):self.layers = layersdef forward(self, x):for layer in self.layers:x = layer.forward(x)return xdef backward(self, G):for layer in self.layers[::-1]:G = layer.backward(G)def __repr__(self):info = ""for layer in self.layers:info += layer.info + "\n"return info# 主函数
if __name__ == "__main__":# 加载训练集图片、标签train_images = (load_images(os.path.join("Python", "NLP basic", "data", "minist", "train-images.idx3-ubyte"))/ 255)train_labels = make_onehot(load_labels(os.path.join("Python", "NLP basic", "data", "minist", "train-labels.idx1-ubyte")),10,)# 加载测试集图片、标签dev_images = (load_images(os.path.join("Python", "NLP basic", "data", "minist", "t10k-images.idx3-ubyte"))/ 255)dev_labels = load_labels(os.path.join("Python", "NLP basic", "data", "minist", "t10k-labels.idx1-ubyte"))# 设置超参数epochs = 10lr = 0.08  # V2版本调整了学习率batch_size = 200# 展开图片数据train_images = train_images.reshape(60000, 784)dev_images = dev_images.reshape(-1, 784)# 调用dataset类和dataloader类train_dataset = Dataset(train_images, train_labels)train_dataloader = DataLoader(train_dataset, batch_size)dev_dataset = Dataset(dev_images, dev_labels)dev_dataloader = DataLoader(dev_dataset, batch_size)# 定义模型model = ModelList([Linear(784, 512),ReLU(),Dropout(0.2),Linear(512, 256),Tanh(),Dropout(0.1),Linear(256, 10),Softmax(),])print(model)# 训练集训练过程for e in range(epochs):for x, l in train_dataloader:# 前向传播x = model.forward(x)# 计算损失loss = -np.mean(l * np.log(x))# 反向传播G = model.backward(l)# 验证集验证并输出预测准确率right_num = 0for x, batch_labels in dev_dataloader:x = model.forward(x)pre_idx = np.argmax(x, axis=-1)  # 预测类别right_num += np.sum(pre_idx == batch_labels)  # 统计正确个数acc = right_num / len(dev_images)  # 计算准确率print(f"Epoch {e}, Acc: {acc:.4f}")

image

http://www.zskr.cn/news/12946.html

相关文章:

  • 30.Linux DHCP 服务器 - 详解
  • 题解:QOJ9619/洛谷13568 [CCPC 2024 重庆站] 乘积,欧拉函数,求和(数论+状压DP)
  • Pytest+requests进行接口自动化测试6.0(Jenkins) - 指南
  • 解析01背包 - 教程
  • 电脑显示器黑屏(闪烁:隔几秒中黑一两秒),向日葵远程正常——DeepSeek问答
  • 消息队列Apache Kafka教程 - 指南
  • 9.21~9.27 周总结
  • 原码 反码 补码
  • Spring Framework 远程命令执行漏洞
  • python基本脚本要素
  • pip安装依赖包报错内容为User defined options,Native files 如何解决
  • edu 107 E(概率期望, dp)
  • Spring MVC的双向数据绑定
  • STM32定时器(寄存器与HAL库实现) - 实践
  • 微前端中iframe集成方式与应用微前端框架方式对比
  • 2025黄鹤杯线上wp
  • 一条频率信道是什么?
  • Unigine整合Myra UI Library全纪录(3):整合与优化
  • 实用指南:AI 时代的安全防线:国产大模型的数据风险与治理路径
  • 写给自己的年终复盘以及未来计划
  • 白居易-那个寒冷的夜晚,思念像潮水般袭来。想得家中夜深坐,还应说着远行人。
  • Metasploit Framework 6.4.90 (macOS, Linux, Windows) - 开源渗透测试框架
  • 秦岭迎来大丰收,徒步才能抵达的村庄,藏着有钱难买的山货!
  • 那些诗词那些花|君不见此玫瑰于晚秋的夜色中凄然绽放,别具一格。
  • Apache Doris性能优化全解析:慢查询定位与引擎深度调优 - 教程
  • 秋风中的窘境,一代诗圣的安居梦
  • 辛弃疾:明月团团高树影,十里水沉烟冷
  • MCP协议:重构AI协作的未来,打破模型边界的技术革命! - 详解
  • Go与C# 谁才更能节省内存? - 详解
  • shiro反序列化及规避检测