《Born》第9章:神经网络模块——从 Linear 到 Transformer Block
Born 的nn包提供了 20+ 原生模块,从最简单的全连接层到完整的 Transformer Block。所有模块都实现了统一的接口。
统一模块接口
typeModule[T DType,B Backend]interface{Forward(x Tensor[T,B])Tensor[T,B]Parameters()[]Tensor[T,B]To(B)Module[T,B]// 切换后端}基础模块
// 全连接层linear:=nn.NewLinear[float32](784,128,backend)// 卷积层conv:=nn.NewConv2D[float32](1,32,3,backend)// 输入1通道,输出32通道,核3×3// 激活函数relu:=nn.NewReLU[float32]()silu:=nn.NewSiLU[float32]()// 归一化layerNorm:=nn.NewLayerNorm[float32](128,backend)rmsNorm:=nn.NewRMSNorm[float32](128,backend)// LLaMA 使用容器模块
// Sequential:按顺序执行model:=nn.NewSequential[float32](nn.NewLinear[float32](784,128,backend),nn.NewReLU[float32](),nn.NewLinear[float32](128,10,backend),)output:=model.Forward(input)Transformer Block(LLaMA 风格)
typeTransformerBlock[T DType,B Backend]struct{Norm1*RMSNorm[T,B]// 预归一化Attention*GQA[T,B]// 分组查询注意力Norm2*RMSNorm[T,B]FFN*SwiGLUFFN[T,B]// SwiGLU 前馈网络}func(b*TransformerBlock[T,B])Forward(x Tensor[T,B])Tensor[T,B]{// 预归一化 + 残差连接h:=x.Add(b.Attention.Forward(b.Norm1.Forward(x)))// 前馈 + 残差连接out:=h.Add(b.FFN.Forward(b.Norm2.Forward(h)))returnout}这是 LLaMA/Mistral/DeepSeek 的核心架构。
参数初始化
// Xavier/Glorot 初始化(适用于 tanh/sigmoid)nn.Xavier(linear.Weight)// Kaiming 初始化(适用于 ReLU)nn.Kaiming(linear.Weight)📘 《Born》连载技术书,第 9/22 章。
