当前位置: 首页 > news >正文

在 Rust 中从头开始训练 LLM

原文:towardsdatascience.com/training-llm-from-scratch-in-rust-03381bbd7204

在这篇配套文章中,我将展示我在 Rust 中从头开始训练一个类似 GPT 模型的实现。没有 GPU,只有 CPU,性能比原生 C 代码高出 30 倍。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/41aa162bccc173da174237588256a530.png

图片由 GoogleDeepMind 在 Unsplash 提供

在我上一篇文章medium.com/towards-data-science/writing-llms-in-rust-looking-for-an-efficient-matrix-multiplication-e9539b0cb9d3中,我介绍了矩阵乘法的问题,注意力算法如何使用矩阵乘法执行平均过程,以及如何高效地实现——至少对我来说——一个矩阵乘法函数在 Rust 中使用 Blas。

在这篇新文章中,我想展示我在 Rust 中实现 llm.c 的第一个构建块,即使用 Rust 从头开始训练一个类似 GPT 的模型。这是我学习更多关于 Rust 生态系统并了解它与 C 相比如何的方法。特别是,我希望我的代码能够从 GPT 权重开始,仅使用 CPU 训练类似 GPT 的模型—— 因此没有 GPU 或 TPU。我的目标是了解我们可以在简单的笔记本电脑上推动这些模型多远,以及 Rust 生态系统可以用于此的程度。最终,此代码也可能有助于使用给定的输入语料库微调 GPT 模型

所有相关的代码片段都可以在 这里 找到。

我希望每个人都能记住以下要点:

  • 如何在 Rust 中处理内存参数和 GPT 参数

  • 注意力前向是如何实现的

  • 如何利用 Rayon 的并行性以及线程锁

1. 处理所有参数和内存布局

类似 GPT 的模型具有大量的参数和张量:嵌入矩阵、层归一化参数、查询、键、值矩阵、注意力输出、前馈层输出等等。如果我们处理的是 PyTorch,所有这些都会自动包含在代码范式之中,因此无需担心独立的张量对象,或者这些张量如何适应内存。相反,在我们的 Rust 实现中,我们确实需要担心并管理所有这些参数。

所有参数都存储为向量Vec<f32>,单个向量,在内存中连续,因此我们可以有:

  • 一种简单的内存方法——我们可以提高缓存局部性,加快矩阵乘法操作、加载和参数保存过程;

  • 一个更简单的接口,可以与 Blas 的sgemm一起使用以实现更好的矩阵乘法;

  • 从磁盘进行简单的读写操作;

另一方面,我们必须为此付出代价,并与数组切片交互:

  • 我们必须知道每个数组的大小;

  • 我们需要谨慎处理切片操作,即正确选择张量的正确部分,以避免索引边界错误;

  • 对于每个参数,我们必须有一个offset变量。例如,wte占用vocab_size*channels个浮点数,所以params_memory[0..vocab_size*channels]。然后,下一个张量wpe将占用params_memory[vocab_size*channels..(vocab_size*channels + max_seq_len*channels)]

let wte=&amp;self.params_memory[0..vocab_size*channels];let wpe=&amp;self.params_memory[vocab_size*channels..(vocab_size*channels+max_seq_len*channels)];

总体而言,唯一的风险是正确切片params_memory数组。如果我们知道大小,我们就不会进行无效的内存访问,我们也在params_memory变量中有一个单一的真实来源。

核心功能,用于参数和内存分配,是gpt_build_from_checkpoint。从这一点,我们正在读取输入文件file.read_f32::<LittleEndian>()>LittleEndian用于读取和写入数字,无论是小端还是大端字节,直接到/从字节数组。然后创建参数如下:

model.params.wte=model.params_memory[offset..offset+model.param_size[0]].to_vec();offset+=model.param_sizes[0];

2a. 编码器前向,编码器后向

在构建类似 GPT 的模型时,需要考虑的第二步是创建单词和位置嵌入,即高维向量嵌入。这是通过encoder_forward实现的,它返回一个大小为[B, T, C]的激活张量。需要记住的一个重要事项是 B、T 和 C 维度的含义。输入数据被细分为块,或大小为 B 的批次。每个批次有一个块大小,T。假设取一个文本语料库,每个句子可能是一个批次,每个句子可能被分为大小为 T 的块。通道大小,C,是最“奇怪”的一个。我们的思维立刻转向图像处理领域,在那里我们有 R、G 和 B 通道。在我们的句子语料库中,通道参数是我们创建嵌入的维度。在我们的情况下C = 768。此参数直接在检查点文件中读取:

let(max_t,v,l,nh,c)=(model_header[2]asusize,model_header[3]asusize,model_header[4]asusize,model_header[5]asusize,model_header[6]asusize,);

特别是对于encoder_forward,我们正在处理一个包含我们标记 ID 集合的输入向量inp,形状为[B x T]。单词嵌入矩阵wte的大小为[V x C],其中V是词汇量大小,即模型可以表示的唯一标记数量,在我们的情况下是 50,000。另一方面,位置嵌入wpe的大小为[max_t x C],其中max_t是最大序列长度,即 1024。

为了适应这些值,我们正在使用 Rust 切片方法,例如在encoder_forward中:

let out_start_idx=b_idx*t*c+t_idx*c;//slicing let out_bt=&amp;mut out[out_start_idx..out_start_idx+c];//slicing let ix=inp[b_idx*t+t_idx]asusize;//take theinputvalues let wte_start_idx=ix*c;//slicing let wte_ix=&amp;wte[wte_start_idx..wte_start_idx+c];//take wte values let wpe_start_idx=t_idx*c;//slicing let wpe_t=&amp;wpe[wpe_start_idx..wpe_start_idx+c];//take wpe valuesforiin0..c{out_bt[i]=wte_ix[i]+wpe_t[i];}

要欣赏 Rust 切片,考虑B=2, T=3, C=4。这意味着输出长度为B x T x C = 24,因此:

out:[out[b=0,t=0,:],out[b=0,t=1,:],out[b=0,t=2,:],out[b=1,t=0,:],out[b=1,t=1,:],out[b=1,t=2,:]]

其中out[b, t, :]是 4 个元素。因此,对于b=1, t=2,即第二个批次和第三个标记,切片从out_start_idx = b_idx x t x c + t_ids x c = 1 x 3 x 4 + 2 x 4 = 20开始。

2b. 层归一化和注意力

GPT 类型的模型在其架构中有一个归一化步骤,这样我们就可以稳定训练并提高性能。LayerNorm 确保每个单独的层都有一个归一化的高斯分布。我们在通道维度C上对每个向量进行归一化,确保均值为零,方差为 1。

在实现layernorm_forward时,我们使用变量eps = 1e-5f32;以防止在计算1 / sqrt( var + eps )时除以零。

归一化完成后,我们可以开始讨论注意力层。在我们的代码中,注意力是一个多头自注意力:

  • 输入维度C被分割到每个头中,使用C/nh

  • 我们为每个标记计算查询Q、键K和值V

  • 我们计算注意力分数

  • 最后,我们根据注意力分数对值进行加权求和

值得记住注意力做了什么。考虑到我们有B, T, and C输入元素,我们想要做的是从输入字符串中取出最多T个标记,并让算法理解它们是如何“相互关联”的。例如,第五个标记应该只考虑它之前的标记,即第一个、第二个、第三个和第四个标记。这样,流始终从当前标记到上一个时间戳标记。

要了解所有标记之间的相互关联程度,我们只需要计算 t-th 标记有多少次可能与前一个 (t – 1)-th 标记连接。为了进行高效的平均,我们使用矩阵乘法技巧。特别是,我们使用三个向量来帮助我们进行平均。前两个向量QK是查询和键向量。查询向量回答的问题是:“我在寻找什么?”,而键向量回答的问题是:“我包含什么?”。现在,在KQ之间进行点积将返回这两个向量对齐的程度,即标记内容(我包含什么?)和标记关联(我在寻找什么?)之间的对齐。

为了使代码更高效,我们正在使用自注意力头。这意味着每个KQ向量都将具有大小B, T, head_size。计算返回一个大小为B, T, T的权重向量。这意味着权重将有B行,正如我们的批次一样多。对于每个批次,我们将有一个T x T的方阵,这是我们的标记的大小。因此,对于每个 t-th 行和 t-th 列的组合,我们将有一个“统计”权重,表示这两个标记一起出现的可能性。

最后一步是对权重向量W使用值向量V进行查询。值向量只是一个简单的线性神经网络层,它应用于输入标记。这直接跟在 softmax 处理之后。输出将具有大小B, T, head_size。我们在这里的挑战将是将每个头大小维度的所有通道维度连接起来。

让我们转到实际方面。我们注意力前向函数的输入项是:

  • out输出缓冲区,这将是一个大小为[B, T, C]的张量

  • preatt存储预 softmax 注意力分数的张量

  • att存储最终 softmax 后概率的张量

  • inp输入特征,我们将从中派生出查询、键和值向量

  • b, t, c, nh分别是批大小、序列长度、总通道数(词汇大小)和注意力过程中的头数

首先,我们准备所有常数。选择c3 = c x 3是为了最终连接Q, K, V向量。

主循环处理所有头,对于所有头,我们遍历所有标记,然后遍历所有批次。偏移量再次计算如下:

let query_start=b_idx*t*c3+t_idx*c3+h*hs;let preatt_start=b_idx*nh*t*t+h*t*t+t_idx*t;let att_start=b_idx*nh*t*t+h*t*t+t_idx*t;

这样就可以提取查询向量:

let query_vec=&amp;inp[query_start..query_start+hs];

特别是,我们有当前头和标记的hs维度查询向量。记住,查询表示“这个标记在之前的标记中寻找什么”。

然后,我们构建键矩阵:

let mut keys_mat=Vec::with_capacity((t_idx+1)*hs);fort2in0..=t_idx{let key_start=b_idx*t*c3+t2*c3+h*hs+c;//+c to skip Qandaccess K keys_mat.extend_from_slice(&amp;inp[key_start..key_start+hs]);}

这里with_capacity构造一个新的、空的Vec<T>,其容量至少为指定的容量。该向量将能够容纳至少capacity个元素而不进行重新分配。如果capacity为 0,则向量不会分配新元素。我们将所有键收集到当前时间步t_idx + 1– 记住,键是hs维度的,正如我们可以从key_start看到的那样。

接下来,我们使用 Blas 进行预注意力分数的计算。

let mut preatt_row=vec![0.0f32;t_idx+1];unsafe{sgemm(Layout::RowMajor,Transpose::None,Transpose::None,(t_idx+1)asi32,1,hsasi32,1.0,&amp;keys_mat,hsasi32,query_vec,1,0.0,&amp;mut preatt_row,1,);}

这里的矩阵乘法是(t_idx + 1) x hs * hs x 1 = (t_idx + 1) x 1,这给出了当前标记的查询与每个先前标记的键的匹配度得分(更详细的解释见这里)。这些是 logits,并且它们通过 softmax 进行归一化,并存储在数组att中。

最后,我们有值向量与注意力分数之间的矩阵乘法。这给我们一个加权求和,对于每个标记,我们知道它与所有之前看到的标记的分数。

3. 尝试利用 Rayon 并行化

在真正使用代码之前,我想花点时间谈谈 Rayon,以及我们如何利用它进行数据并行化。

Rayon 是一个数据并行库,它允许我们轻松地在多个线程上运行。正如我在关于 matmul 的帖子中之前看到的,我们可以使用并行迭代器par_iter()par_chunks()par_chunks_mut()。这些迭代器可以直接在所有需要的线程上分区数据负载,而不需要你做原始和脏活。这给了我们在使用和安全性方面的一些简单性。

你可能在代码中看到如下行:

out.par_chunks_mut(oc).for_each(|row|{for(o,val)inrow.iter_mut().enumerate(){*val+=bias[o];}});forrowinout.chunks_mut(oc){//...}

Rayon 将out数组分割成大小为oc的块,并在并行线程中处理它们。每个线程都得到一个单独的块来处理,这样就不会有重叠或对相同数据的竞争。这可以添加到layernorm函数以及encoder函数中,因为我们能够处理更大的数据集,并确保更好的并行化。

然而,并非所有闪亮的东西都是金子。一些操作,如将梯度累积到单个数组中,或在多个线程中汇总统计数据,需要共享状态。共享状态意味着我们不能分割数据,但我们需要所有数据同时存在,因此需要同步。实现共享状态很复杂,因为我们需要防止线程在不协调的情况下同时写入相同的内存地址。因此,我们需要Mutex<T>Mutex提供了互斥排他,这样一次只有一个线程可以锁定互斥锁,确保它是唯一修改包含数据的线程。

use std::sync::Mutex;let shared_data=Mutex::new(vec![0.0f32;size]);(batches_in_parallel).for_each(|batch|{let mut guard=shared_data.lock().unwrap();for(g,val)inguard.iter_mut().zip(batch){*g+=val;}});

如果你看到我的attention_backward函数,你会看到它被分成多个子块。这主要是为了避免错误Cannot borrow as mutable more than once at a time。此外,在这里我强烈使用 Rayon 和 Mutex,以允许在过程中进行一些并行处理。

实际上,在反向传播过程中,我们需要计算相对于输入dinp的梯度,相对于预 softmax 注意力分数dpreatt,以及相对于注意力概率datt的梯度。这是在整个批次上进行的,因此,不出所料,我们确实需要并行化,以避免瓶颈(你将在下面看到这是最耗时的步骤)。我们想要的并行过程是针对批次和每个注意力头,这样我们就可以独立处理这些。然而,因为我总是避免错误Cannot borrow as mutable more than once at a time,我们需要每个线程计算其局部结果,并将所有这些合并成最终的全球梯度。为此,我需要使用Mutex

let global_dinp=Mutex::new(vec![0.0f32;dinp.len()]);let global_datt=Mutex::new(vec![0.0f32;datt.len()]);let global_dpreatt=Mutex::new(vec![0.0f32;dpreatt.len()]);

这样我就可以在循环中使用 Rayon 进行并行工作:

(0..b).into_par_iter().for_each(|b_idx|{let mut local_dinp=vec![0.0f32;dinp.len()];let mut local_datt=vec![0.0f32;datt.len()];let mut local_dpreatt=vec![0.0f32;dpreatt.len()];});

让每个线程在不同的b_idx上工作,这样我们就可以局部计算梯度。所有这些都是在隔离状态下完成的,因此每个线程都在其局部数组上工作。

在进行局部计算后,我们需要将所有数组组合成全局梯度数组:

{let mut g_dinp=global_dinp.lock().unwrap();g_dinp.iter_mut().zip(local_dinp.iter()).for_each(|(g,l)|*g+=l);}

这种lock机制防止了线程以不一致的方式交错写入。当一个线程持有互斥锁时,它具有独占访问权。

在最后一步,我们将切片复制到那里,以将局部结果传播到最终数组中

dinp.copy_from_slice(&amp;global_dinp.lock().unwrap());datt.copy_from_slice(&amp;global_datt.lock().unwrap());dpreatt.copy_from_slice(&amp;global_dpreatt.lock().unwrap());

是时候享受乐趣了:代码、性能和推理!

现在是时候与代码玩耍了。所有这些计算都是在 MacBook Pro,M2,16 GB 内存上运行的。

首先,确保使用python prepro_tinyshakespeare.py下载所需的数据。这将下载输入语料库到data文件夹中。文本被转换为输入训练和验证标记(分别为tiny_shakespeare_train.bintiny_shakespeare_val.bin)。文本使用 GPT-2 标记器进行标记化。然后,你可以使用以下命令构建 rust 代码:

cd llm bash build.sh

经过 2000 步之后,你可能得到一个类似的推理输出:

3792,Is340,it922,good11,,611,if345,you423,have26246,pity11,,284,to423,have281,an45618,ornament11,,257,a1486,design11,,198,2514,To9280,dance11,,7365,bat258,he11,,18044,breathe290,and4545,teach30,?440,O11,,611,if340,it307,be2081,true11,,198,1026,It318,is2081,true356,we743,may307,be991,still2877,living11,,611,if340,it307,be2081,true25,:198,46,O11,,2652,stay11,,393,or314,I2236,shall307,be2636,dead13,.628,

其中我正在打印标记 ID 及其文本值。代码在 16 个线程上运行。要选择线程数,你可以修改此代码中的行和此bash 构建脚本中的行。

图 1 显示了每一步的前向和反向传递时间。耗时以ms为单位。总体来看,我们可以看到前向传递有相当好的优化,平均时间为272.01 +/- 57.71 ms。必须做一些工作来使反向传递更高效,因为它的平均时间为472.63 +/- 51.75 ms。这些耗时比 Karpathy 的原始提交——作为我 Rust 代码的主要灵感来源——平均每步需要30 秒要好 30 倍。

<…/Images/c182273420335c98b3e3cd67a2834d81.png>

图 1:LLM 训练中前向和反向传递的耗时。耗时以毫秒(ms)为单位,代码在针对小型莎士比亚数据集运行了 2000 步。[图片由作者提供]

同时,我们可以测量和跟踪训练损失,如图 2 所示。总体来看,有一个趋势,从最初的平均 4.5 下降到最后的 3.2 左右。

<…/Images/1614f4509d6fcb84c1165b05984f5e58.png>

图 2:2000 步后的训练损失。[图片由作者提供]

进一步的推理示例,在 2000 步之后使用 GPT-2 标记化生成

Tis come;you'll bear it,this fierce protestor..JULIET:We will.O' the loaning makers.First watch man:Youandyour lads,your actions must be controll'd by Sir John

这可能不是 LLM 的最佳结果,但这是在 2000 步之后,仅仅在 30 分钟训练后,对小型莎士比亚数据集进行微调的结果。


结论

如果你已经到达这里,非常感谢阅读我的文章。我希望你已经全面地查看了我的代码,并准备好微调 GPT 模型。

文章展示了我深入 Rust 语言并学习如何优化训练 GPT-like 模型代码的方法。特别是,我们学习了:

  • 如何在 Rust 中实现一个类似 GPT 的模型,主要障碍有哪些(例如矩阵乘法、设置参数),以及如何处理内存管理、张量切片和并行化。

  • 如何利用 Rayon,使用数据并行循环,使用Mutex同步数据。

  • 该项目的目标是运行基于简单 CPU 的笔记本电脑上的 GPT 训练,甚至在这样硬件上,Rust 代码看起来也比 C 语言版本要好得多。许多更多项目可以从这个初始种子开始,从微调到在 Rust 中实现更好的优化。

如果你想与我联系,你可以发送电子邮件到 [email protected]。

http://www.zskr.cn/news/1458912.html

相关文章:

  • 工业吸尘器品牌选择要点:从性能到服务的全面解析 - 品牌排行榜
  • Step 3.5 Flash:面向工业API的7B大模型推理范式重构
  • 告别示教器:用C#写个WinForm小工具,实时监控ABB机器人状态和日志
  • 3分钟颠覆传统:百度网盘提取码智能获取工具如何重构你的数字资源世界
  • LLVM IR指令避坑指南:`nuw`/`nsw`、`exact`这些关键字用错了会怎样?
  • 质量好的工业吸尘器选购要点与品牌解析 - 品牌排行榜
  • 实战指南:基于快马生成生产级PyTorch模型推理镜像与部署方案
  • 【Redis从入门到精通】第44篇:Sentinel启动与监控——它是怎么盯着主服务器的
  • 别再死记硬背!用‘客户服务系统’实战案例,轻松搞懂UML类图与包图设计
  • PHP风控系统与反欺诈策略
  • 新手避坑指南:用BC35-G模块和AT指令,5分钟搞定NBIOT设备上云OneNET
  • FPGA上跑的纯硬件俄罗斯方块:Verilog代码+VGA显示+完整编译工程
  • PHP魔术方法深入理解与实战
  • DeepSeek V4实测:MoE架构与百万上下文的工程真相
  • 从零打造 99.99% 在线 CRM:高可用架构设计与系统化工程方法论
  • 魔兽争霸III终极性能优化:三大核心功能免费解决宽屏适配、地图加载与帧率限制
  • Qwen3.6-Plus工程落地指南:Agent底座的可交付实践
  • AI生成可玩游戏:单文件HTML卡丁车实战指南
  • 从啤酒瓶到二维码:手把手教你复用Gazebo官方模型,打造自定义贴图仿真资产
  • AI工具如何重塑法律服务效率?揭秘2024智能法务整合的7个关键决策点
  • 开源报表工具JimuReport实战:手把手教你配置SQL数据源并生成动态销售报表
  • Spartan-6 FPGA上跑通AD9238双路12位25MHz实时采集的完整ISE工程包
  • 道路积水数据集 路面积水识别数据集 图片数量4524,xml和txt标签都有;公路积水数据集 ✓类别:puddle;
  • 第九章:Token 优化与高效省钱配置(重点)
  • 语义内核形式化模型:AI内容生成的统一数学原理与工程实践
  • Vue版Cesium卫星轨道+雷达扫描三维可视化组件(含CZML数据与小程序适配)
  • 气缸驱动并联机器人位姿控制策略【附仿真】
  • DeepSeek V4实测:百万上下文与MoE架构如何重构AI成本模型
  • 深耕车载数字健康场景,守护全维度驾乘安全与体验
  • GBase 8s数据库高可用之—RHAC远程高可用集群详解