作者 / 魏巍,開發(fā)技術(shù)推廣工程師
如果您對如何使用 JAX 從頭開始構(gòu)建語言模型感到好奇,那么本文非常適合您。我們在 2025 年 Google Cloud Next 大會上舉辦了一場關(guān)于此主題的研討會,并獲得了一些很好的反饋,我們也為所有無法參會的開發(fā)者編寫了這份指南。
本文和代碼示例將引導(dǎo)您構(gòu)建并預(yù)訓(xùn)練 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的強大能力。您可以使用 Colab 或 Kaggle 中的 TPU 免費運行整個項目,并獲取完整的Notebook。
Notebook
https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining
這是一個實踐教程,如果您還不熟悉 JAX,我們建議您從《PyTorch 開發(fā)者指南: JAX 基礎(chǔ)知識》入手。
PyTorch 開發(fā)者指南: JAX 基礎(chǔ)知識
https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
首先,讓我們快速了解一下將要用到的工具。
JAX 生態(tài)系統(tǒng)
在開始構(gòu)建模型之前,讓我們先簡要介紹一下 JAX 生態(tài)系統(tǒng)。JAX 生態(tài)系統(tǒng)采用模塊化方法,通過 JAX 核心提供核心數(shù)值處理能力,而一系列豐富的庫則在此基礎(chǔ)上構(gòu)建而成,以滿足不同應(yīng)用的特定需求,如用于構(gòu)建神經(jīng)網(wǎng)絡(luò)的Flax、用于檢查點和模型持久性的Orbax以及用于優(yōu)化的Optax(在本文中,這 3 個工具都將被用到)。內(nèi)置函數(shù)轉(zhuǎn)換,如 autograd、矢量化和 JIT 編譯,加上強大的性能和易于使用的 API,使 JAX 非常適合訓(xùn)練大語言模型。
JAX 生態(tài)系統(tǒng)
https://docs.jax.dev/en/latest/#ecosystem
Flax
https://github.com/google/flax
Orbax
https://github.com/google/orbax
Optax
https://github.com/google-deepmind/optax
入門指南: 構(gòu)建您的 GPT-2 模型
OpenAI 此前發(fā)布了GPT-2 模型代碼和權(quán)重,這為我們提供了寶貴的參考資料,并且社區(qū)也付出了很多努力來復(fù)現(xiàn)該模型,例如nanoGPT。以下是 GPT-2 的高層級模型架構(gòu)圖:
GPT-2 模型代碼和權(quán)重
https://github.com/openai/gpt-2
nanoGPT
https://github.com/karpathy/nanoGPT
我們將使用NNX (新的 Flax 接口)來構(gòu)建 GPT-2 模型。簡潔起見,我們重點關(guān)注 Transformer Block,這是現(xiàn)代大語言模型的關(guān)鍵所在。Transformer Block 會捕獲任何序列的長程依賴關(guān)系,并構(gòu)建豐富的上下文理解。GPT-2 Transformer Block 由 2 個 LayerNorm 層、1 個多頭注意力 (MHA) 層、2 個 Dropout 層、2 個線性投影層和 2 個殘差連接組成。因此,我們首先需要在TransformerBlock類的__init__函數(shù)中定義這些層:
classTransformerBlock(nnx.Module): def__init__( self, embed_dim:int, num_heads:int, ff_dim:int, dropout_rate:float, rngs: nnx.Rngs, ): self.layer_norm1 = nnx.LayerNorm( epsilon=1e-6, num_features=embed_dim, rngs=rngs ) self.mha = nnx.MultiHeadAttention( num_heads=num_heads, in_features=embed_dim, rngs=rngs ) self.dropout1 = nnx.Dropout(rate=dropout_rate) self.layer_norm2 = nnx.LayerNorm( epsilon=1e-6, num_features=embed_dim, rngs=rngs ) self.linear1 = nnx.Linear( in_features=embed_dim, out_features=ff_dim, rngs=rngs ) self.linear2 = nnx.Linear( in_features=ff_dim, out_features=embed_dim, rngs=rngs ) self.dropout2 = nnx.Dropout(rate=dropout_rate)
NNX (新的 Flax 接口)
https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html#
接下來,我們需要在__call__函數(shù)中對這些層進行組合:
classTransformerBlock(nnx.Module): def__call__(self, inputs, training:bool=False): input_shape = inputs.shape bs, seq_len, emb_sz = input_shape attention_output = self.mha( inputs_q=self.layer_norm1(inputs), mask=causal_attention_mask(seq_len), decode=False, ) x = inputs + self.dropout1( attention_output, deterministic=nottraining ) # MLP mlp_output = self.linear1(self.layer_norm2(x)) mlp_output = nnx.gelu(mlp_output) mlp_output = self.linear2(mlp_output) mlp_output = self.dropout2( mlp_output, deterministic=nottraining ) returnx + mlp_output
如果您使用過任何其他機器學(xué)習(xí)框架 (如 PyTorch 或 TensorFlow) 來訓(xùn)練語言模型,那么您對這段代碼應(yīng)該非常熟悉。但 JAX 具有通過SPMD(Single Program Multiple Data) 自動并行運行代碼的強大能力。這項功能至關(guān)重要,因為我們將在多個加速器 (多個 TPU 核心) 上運行代碼。讓我們來看看它的工作原理。
SPMD
https://docs.jax.dev/en/latest/sharded-computation.html
要執(zhí)行 SPMD,首先我們需要確保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,請選擇 TPU 運行時 (您也可以使用 Cloud TPU 虛擬機)。
import jax jax.devices() # Free-tier Colab offers TPU v2: #[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), # TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), # TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), # TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), # TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), # TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), # TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), # TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 個獨立的 TPU 核心。TPU v3 托盤的外觀如下所示:
訓(xùn)練您的 GPT-2 模型
為了高效訓(xùn)練 GPT-2 模型,我們將通過 SPMD 讓所有 TPU 核心協(xié)同運行,并利用 JAX 中的數(shù)據(jù)并行。為此,我們定義了一個硬件網(wǎng)格:
mesh= jax.make_mesh((8,1), ('batch','model'))
數(shù)據(jù)并行
https://en.wikipedia.org/wiki/Data_parallelism
我們可以將網(wǎng)格視為加速器的 2D 矩陣。在本例中,我們?yōu)榫W(wǎng)格定義了兩個軸:batch軸和model軸。因此,我們總共有 8 x 1 個核心,也就是 8 個核心。這些軸決定了我們?nèi)绾蝿澐謹?shù)據(jù)和模型參數(shù)。如果之后想嘗試其他并行方案,我們可以對這些軸進行調(diào)整。
現(xiàn)在,我們通過告訴 JAX 如何使用 "model" 軸劃分模型參數(shù)來更改__init__函數(shù)。這是通過在初始化權(quán)重張量 (weight tensors) 時添加nnx.with_partitioning來實現(xiàn)的: 對于像 LayerNorm 縮放/偏置張量這樣的 1D 權(quán)重張量 (weight tensors),我們直接沿著 "model" 軸對它們進行分片;對于像 MHA 和線性內(nèi)核張量這樣的 2D 權(quán)重張量,我們沿著model軸對第二維度進行分片。
classTransformerBlock(nnx.Module): def__init__( self, embed_dim:int, num_heads:int, ff_dim:int, dropout_rate:float, rngs: nnx.Rngs, ): self.layer_norm1 = nnx.LayerNorm( epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs, scale_init=nnx.with_partitioning( nnx.initializers.ones_init(), ("model"), ), bias_init=nnx.with_partitioning( nnx.initializers.zeros_init(), ("model"), ), ) self.mha = nnx.MultiHeadAttention( num_heads=num_heads, in_features=embed_dim, kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), (None,"model"), ), bias_init=nnx.with_partitioning( nnx.initializers.zeros_init(), ("model"), ), ) # Other layers in the block are omitted for brevity
我們需要像這樣劃分其他層,以便為整個 GPT-2 模型啟用模型張量并行。即使我們在本教程中不會使用模型張量并行,實現(xiàn)這一功能仍然是比較好的做法,因為隨著模型規(guī)模的增長,我們將來可能需要對模型參數(shù)進行分區(qū)。實現(xiàn)后,我們只需更改一行代碼即可立即運行更大的模型。例如:
mesh= jax.make_mesh((4,2), ('batch','model'))
接下來,我們需要定義loss_fn和train_step函數(shù),與此前文章類似。train_step()函數(shù)會計算交叉熵損失函數(shù)的梯度,并通過優(yōu)化器更新權(quán)重,然后在循環(huán)中被調(diào)用來訓(xùn)練模型。為了獲得最佳性能,我們使用@nnx.jit裝飾器對這兩個函數(shù)進行 JIT 編譯,因為它們屬于計算密集型函數(shù)。
@nnx.jit defloss_fn(model, batch): logits = model(batch[0]) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch[1] ).mean() returnloss, logits @nnx.jit deftrain_step( model: nnx.Module, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, ): grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(model, batch) metrics.update(loss=loss, logits=logits, lables=batch[1]) optimizer.update(grads)
此前文章
https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
對于優(yōu)化器,我們使用 Optax 中的 AdamW 以及余弦衰減調(diào)度。您也可以在 Optax 中試用其他優(yōu)化器或調(diào)度計劃。
schedule = optax.cosine_decay_schedule( init_value=init_learning_rate, decay_steps=max_steps ) optax_chain = optax.chain( optax.adamw(learning_rate=schedule, weight_decay=weight_decay) ) optimizer = nnx.Optimizer(model, optax_chain)
其他優(yōu)化器
https://optax.readthedocs.io/en/latest/api/optimizers.html
調(diào)度計劃
https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
最后,我們需要創(chuàng)建一個簡單的訓(xùn)練循環(huán)。
while True: input_batch, target_batch =get_batch("train") train_step( model, optimizer, train_metrics, jax.device_put( (input_batch, target_batch), NamedSharding(mesh,P("batch", None)), ), ) step +=1 if step > max_steps: break
請注意我們使用jax.device_put函數(shù)沿著 batch 軸對輸入數(shù)據(jù)進行分區(qū)。在這種情況下,JAX 將啟用數(shù)據(jù)并行,并通過自動插入通信集合 (AllReduce) 將所有內(nèi)容整合在一起,同時盡可能多地實現(xiàn)計算與通信的重疊。有關(guān)并行計算更深入的討論,請參閱 JAX 的并行編程入門文檔。
并行編程入門
https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#intro-and-a-quick-example
模型此時應(yīng)處于訓(xùn)練狀態(tài),如果使用權(quán)重和偏差來跟蹤運行情況,我們便可以觀察訓(xùn)練損失。以下是訓(xùn)練 GPT-2 124M 模型的測試運行結(jié)果:
權(quán)重和偏差
https://wandb.ai/site
如果使用 Kaggle TPU v3,訓(xùn)練時間大約為 7 個小時 (我們可以不中斷地使用 Kaggle TPU v3 9 個小時);但如果使用Trillium,訓(xùn)練時間將縮短至約 1.5 個小時 (請注意,Trillium 的每個芯片配備 32G 高帶寬內(nèi)存 (HBM),因此我們可以將批量大小加倍,并將訓(xùn)練步數(shù)減半)。
Trillium
https://cloud.google.com/blog/products/compute/trillium-tpu-is-ga
最終的損失情況與nanoGPT 的損失情況大致相符。我們在編寫此代碼示例時對 nanoGPT 進行了研究。
nanoGPT 的損失情況
https://github.com/karpathy/nanoGPT/tree/master?tab=readme-ov-file#baselines
如果使用 Cloud TPU,我們還可以通過 "tpu-info" 命令 (Cloud TPU 監(jiān)控調(diào)試包的一部分) 或權(quán)重和偏差儀表盤監(jiān)控 TPU 利用率。我們的 TPU 正在全力運行!
Cloud TPU 監(jiān)控調(diào)試
https://github.com/AI-Hypercomputer/cloud-tpu-monitoring-debugging
完成模型訓(xùn)練后,我們可以使用Orbax保存模型:
checkpointer = orbax.PyTreeCheckpointer() train_state = nnx.pure(nnx.state(model)) checkpointer.save(checkpoint_path, train_state)
Orbax
https://github.com/google/orbax
后續(xù)步驟: 探索高級 LLM 訓(xùn)練和擴展
這基本上就是我們訓(xùn)練 GPT-2 模型所需了解的全部內(nèi)容。您可以在完整的Notebook中找到其他詳細信息,如數(shù)據(jù)加載、超參數(shù)、指標等。
Notebook
https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining
當(dāng)然,GPT-2 如今還是一個小模型,許多前沿實驗室正在訓(xùn)練擁有數(shù)千億參數(shù)的模型。但是,現(xiàn)在您已經(jīng)學(xué)習(xí)了如何使用 JAX 和 TPU 構(gòu)建小語言模型,為深入了解如何擴展模型做好了準備。
如何擴展模型
https://jax-ml.github.io/scaling-book/
此外,您既可以使用MaxText來訓(xùn)練預(yù)構(gòu)建的前沿 LLM,也可以通過參考JAX LLM 示例或Stanford Marin 模型來學(xué)習(xí)如何從頭開始構(gòu)建最新的模型。
MaxText
https://github.com/AI-Hypercomputer/maxtext
JAX LLM 示例
https://github.com/jax-ml/jax-llm-examples/
Stanford Marin 模型
https://developers.googleblog.com/en/stanfords-marin-foundation-model-first-fully-open-model-developed-using-jax/
我們期待看到您使用 JAX 和 TPU 構(gòu)建的出色模型!
-
模型
+關(guān)注
關(guān)注
1文章
3607瀏覽量
51408 -
代碼
+關(guān)注
關(guān)注
30文章
4921瀏覽量
72199 -
TPU
+關(guān)注
關(guān)注
0文章
160瀏覽量
21430 -
pytorch
+關(guān)注
關(guān)注
2文章
812瀏覽量
14413
原文標題:實戰(zhàn)指南|手把手教您在 TPU 上免費使用 JAX 訓(xùn)練 GPT-2 模型
文章出處:【微信號:Google_Developers,微信公眾號:谷歌開發(fā)者】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
用PaddleNLP在4060單卡上實踐大模型預(yù)訓(xùn)練技術(shù)

如何利用Google Colab的云TPU加速Keras模型訓(xùn)練
OpenAI發(fā)布了一個“逆天”的AI模型——GPT2整個模型包含15億個參數(shù)
OpenAI發(fā)布一款令人印象深刻的語言模型GPT-2
布朗大學(xué)90后研究生:我們復(fù)現(xiàn)了15億參數(shù)GPT-2模型,你也行!
OpenAI宣布,發(fā)布了7.74億參數(shù)GPT-2語言模型
和AI聊天,自然語言模型 GPT-2可能會推出個人信息
GPT系列的“高仿” 最大可達GPT-3大小 自主訓(xùn)練
使用NVIDIA TensorRT優(yōu)化T5和GPT-2

基于OpenAI的GPT-2的語言模型ProtGPT2可生成新的蛋白質(zhì)序列
GPT/GPT-2/GPT-3/InstructGPT進化之路
ELMER: 高效強大的非自回歸預(yù)訓(xùn)練文本生成模型
DeepSpeed里面和Zero相關(guān)技術(shù)教程

DeepSpeed結(jié)合Megatron-LM訓(xùn)練GPT2模型筆記

用PaddleNLP為GPT-2模型制作FineWeb二進制預(yù)訓(xùn)練數(shù)據(jù)集

評論