作者 / 魏巍,開發(fā)技術(shù)推廣工程師
如果您對如何使用 JAX 從頭開始構(gòu)建語言模型感到好奇,那么本文非常適合您。我們在 2025 年 Google Cloud Next 大會上舉辦了一場關(guān)于此主題的研討會,并獲得了一些很好的反饋,我們也為所有無法參會的開發(fā)者編寫了這份指南。
本文和代碼示例將引導(dǎo)您構(gòu)建并預(yù)訓(xùn)練 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的強(qiáng)大能力。您可以使用 Colab 或 Kaggle 中的 TPU 免費(fèi)運(yùn)行整個項目,并獲取完整的Notebook。
Notebook
https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining
這是一個實踐教程,如果您還不熟悉 JAX,我們建議您從《PyTorch 開發(fā)者指南: JAX 基礎(chǔ)知識》入手。
PyTorch 開發(fā)者指南: JAX 基礎(chǔ)知識
https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
首先,讓我們快速了解一下將要用到的工具。
JAX 生態(tài)系統(tǒng)
在開始構(gòu)建模型之前,讓我們先簡要介紹一下 JAX 生態(tài)系統(tǒng)。JAX 生態(tài)系統(tǒng)采用模塊化方法,通過 JAX 核心提供核心數(shù)值處理能力,而一系列豐富的庫則在此基礎(chǔ)上構(gòu)建而成,以滿足不同應(yīng)用的特定需求,如用于構(gòu)建神經(jīng)網(wǎng)絡(luò)的Flax、用于檢查點和模型持久性的Orbax以及用于優(yōu)化的Optax(在本文中,這 3 個工具都將被用到)。內(nèi)置函數(shù)轉(zhuǎn)換,如 autograd、矢量化和 JIT 編譯,加上強(qiáng)大的性能和易于使用的 API,使 JAX 非常適合訓(xùn)練大語言模型。
JAX 生態(tài)系統(tǒng)
https://docs.jax.dev/en/latest/#ecosystem
Flax
https://github.com/google/flax
Orbax
https://github.com/google/orbax
Optax
https://github.com/google-deepmind/optax
入門指南: 構(gòu)建您的 GPT-2 模型
OpenAI 此前發(fā)布了GPT-2 模型代碼和權(quán)重,這為我們提供了寶貴的參考資料,并且社區(qū)也付出了很多努力來復(fù)現(xiàn)該模型,例如nanoGPT。以下是 GPT-2 的高層級模型架構(gòu)圖:

GPT-2 模型代碼和權(quán)重
https://github.com/openai/gpt-2
nanoGPT
https://github.com/karpathy/nanoGPT
我們將使用NNX (新的 Flax 接口)來構(gòu)建 GPT-2 模型。簡潔起見,我們重點關(guān)注 Transformer Block,這是現(xiàn)代大語言模型的關(guān)鍵所在。Transformer Block 會捕獲任何序列的長程依賴關(guān)系,并構(gòu)建豐富的上下文理解。GPT-2 Transformer Block 由 2 個 LayerNorm 層、1 個多頭注意力 (MHA) 層、2 個 Dropout 層、2 個線性投影層和 2 個殘差連接組成。因此,我們首先需要在TransformerBlock類的__init__函數(shù)中定義這些層:
classTransformerBlock(nnx.Module):
def__init__(
self,
embed_dim:int,
num_heads:int,
ff_dim:int,
dropout_rate:float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim, rngs=rngs
)
self.dropout1 = nnx.Dropout(rate=dropout_rate)
self.layer_norm2 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.linear1 = nnx.Linear(
in_features=embed_dim, out_features=ff_dim, rngs=rngs
)
self.linear2 = nnx.Linear(
in_features=ff_dim, out_features=embed_dim, rngs=rngs
)
self.dropout2 = nnx.Dropout(rate=dropout_rate)
NNX (新的 Flax 接口)
https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html#
接下來,我們需要在__call__函數(shù)中對這些層進(jìn)行組合:
classTransformerBlock(nnx.Module):
def__call__(self, inputs, training:bool=False):
input_shape = inputs.shape
bs, seq_len, emb_sz = input_shape
attention_output = self.mha(
inputs_q=self.layer_norm1(inputs),
mask=causal_attention_mask(seq_len),
decode=False,
)
x = inputs + self.dropout1(
attention_output, deterministic=nottraining
)
# MLP
mlp_output = self.linear1(self.layer_norm2(x))
mlp_output = nnx.gelu(mlp_output)
mlp_output = self.linear2(mlp_output)
mlp_output = self.dropout2(
mlp_output, deterministic=nottraining
)
returnx + mlp_output
如果您使用過任何其他機(jī)器學(xué)習(xí)框架 (如 PyTorch 或 TensorFlow) 來訓(xùn)練語言模型,那么您對這段代碼應(yīng)該非常熟悉。但 JAX 具有通過SPMD(Single Program Multiple Data) 自動并行運(yùn)行代碼的強(qiáng)大能力。這項功能至關(guān)重要,因為我們將在多個加速器 (多個 TPU 核心) 上運(yùn)行代碼。讓我們來看看它的工作原理。
SPMD
https://docs.jax.dev/en/latest/sharded-computation.html
要執(zhí)行 SPMD,首先我們需要確保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,請選擇 TPU 運(yùn)行時 (您也可以使用 Cloud TPU 虛擬機(jī))。
import jax jax.devices() # Free-tier Colab offers TPU v2: #[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), # TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), # TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), # TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), # TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), # TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), # TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), # TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 個獨(dú)立的 TPU 核心。TPU v3 托盤的外觀如下所示:

訓(xùn)練您的 GPT-2 模型
為了高效訓(xùn)練 GPT-2 模型,我們將通過 SPMD 讓所有 TPU 核心協(xié)同運(yùn)行,并利用 JAX 中的數(shù)據(jù)并行。為此,我們定義了一個硬件網(wǎng)格:
mesh= jax.make_mesh((8,1), ('batch','model'))
數(shù)據(jù)并行
https://en.wikipedia.org/wiki/Data_parallelism
我們可以將網(wǎng)格視為加速器的 2D 矩陣。在本例中,我們?yōu)榫W(wǎng)格定義了兩個軸:batch軸和model軸。因此,我們總共有 8 x 1 個核心,也就是 8 個核心。這些軸決定了我們?nèi)绾蝿澐謹(jǐn)?shù)據(jù)和模型參數(shù)。如果之后想嘗試其他并行方案,我們可以對這些軸進(jìn)行調(diào)整。
現(xiàn)在,我們通過告訴 JAX 如何使用 "model" 軸劃分模型參數(shù)來更改__init__函數(shù)。這是通過在初始化權(quán)重張量 (weight tensors) 時添加nnx.with_partitioning來實現(xiàn)的: 對于像 LayerNorm 縮放/偏置張量這樣的 1D 權(quán)重張量 (weight tensors),我們直接沿著 "model" 軸對它們進(jìn)行分片;對于像 MHA 和線性內(nèi)核張量這樣的 2D 權(quán)重張量,我們沿著model軸對第二維度進(jìn)行分片。
classTransformerBlock(nnx.Module):
def__init__(
self,
embed_dim:int,
num_heads:int,
ff_dim:int,
dropout_rate:float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
scale_init=nnx.with_partitioning(
nnx.initializers.ones_init(),
("model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(None,"model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
# Other layers in the block are omitted for brevity
我們需要像這樣劃分其他層,以便為整個 GPT-2 模型啟用模型張量并行。即使我們在本教程中不會使用模型張量并行,實現(xiàn)這一功能仍然是比較好的做法,因為隨著模型規(guī)模的增長,我們將來可能需要對模型參數(shù)進(jìn)行分區(qū)。實現(xiàn)后,我們只需更改一行代碼即可立即運(yùn)行更大的模型。例如:
mesh= jax.make_mesh((4,2), ('batch','model'))
接下來,我們需要定義loss_fn和train_step函數(shù),與此前文章類似。train_step()函數(shù)會計算交叉熵?fù)p失函數(shù)的梯度,并通過優(yōu)化器更新權(quán)重,然后在循環(huán)中被調(diào)用來訓(xùn)練模型。為了獲得最佳性能,我們使用@nnx.jit裝飾器對這兩個函數(shù)進(jìn)行 JIT 編譯,因為它們屬于計算密集型函數(shù)。
@nnx.jit
defloss_fn(model, batch):
logits = model(batch[0])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch[1]
).mean()
returnloss, logits
@nnx.jit
deftrain_step(
model: nnx.Module,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
batch,
):
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, lables=batch[1])
optimizer.update(grads)
此前文章
https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
對于優(yōu)化器,我們使用 Optax 中的 AdamW 以及余弦衰減調(diào)度。您也可以在 Optax 中試用其他優(yōu)化器或調(diào)度計劃。
schedule = optax.cosine_decay_schedule( init_value=init_learning_rate, decay_steps=max_steps ) optax_chain = optax.chain( optax.adamw(learning_rate=schedule, weight_decay=weight_decay) ) optimizer = nnx.Optimizer(model, optax_chain)
其他優(yōu)化器
https://optax.readthedocs.io/en/latest/api/optimizers.html
調(diào)度計劃
https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html
最后,我們需要創(chuàng)建一個簡單的訓(xùn)練循環(huán)。
while True:
input_batch, target_batch =get_batch("train")
train_step(
model,
optimizer,
train_metrics,
jax.device_put(
(input_batch, target_batch),
NamedSharding(mesh,P("batch", None)),
),
)
step +=1
if step > max_steps:
break
請注意我們使用jax.device_put函數(shù)沿著 batch 軸對輸入數(shù)據(jù)進(jìn)行分區(qū)。在這種情況下,JAX 將啟用數(shù)據(jù)并行,并通過自動插入通信集合 (AllReduce) 將所有內(nèi)容整合在一起,同時盡可能多地實現(xiàn)計算與通信的重疊。有關(guān)并行計算更深入的討論,請參閱 JAX 的并行編程入門文檔。
并行編程入門
https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#intro-and-a-quick-example
模型此時應(yīng)處于訓(xùn)練狀態(tài),如果使用權(quán)重和偏差來跟蹤運(yùn)行情況,我們便可以觀察訓(xùn)練損失。以下是訓(xùn)練 GPT-2 124M 模型的測試運(yùn)行結(jié)果:

權(quán)重和偏差
https://wandb.ai/site
如果使用 Kaggle TPU v3,訓(xùn)練時間大約為 7 個小時 (我們可以不中斷地使用 Kaggle TPU v3 9 個小時);但如果使用Trillium,訓(xùn)練時間將縮短至約 1.5 個小時 (請注意,Trillium 的每個芯片配備 32G 高帶寬內(nèi)存 (HBM),因此我們可以將批量大小加倍,并將訓(xùn)練步數(shù)減半)。
Trillium
https://cloud.google.com/blog/products/compute/trillium-tpu-is-ga
最終的損失情況與nanoGPT 的損失情況大致相符。我們在編寫此代碼示例時對 nanoGPT 進(jìn)行了研究。

nanoGPT 的損失情況
https://github.com/karpathy/nanoGPT/tree/master?tab=readme-ov-file#baselines
如果使用 Cloud TPU,我們還可以通過 "tpu-info" 命令 (Cloud TPU 監(jiān)控調(diào)試包的一部分) 或權(quán)重和偏差儀表盤監(jiān)控 TPU 利用率。我們的 TPU 正在全力運(yùn)行!

Cloud TPU 監(jiān)控調(diào)試
https://github.com/AI-Hypercomputer/cloud-tpu-monitoring-debugging
完成模型訓(xùn)練后,我們可以使用Orbax保存模型:
checkpointer = orbax.PyTreeCheckpointer() train_state = nnx.pure(nnx.state(model)) checkpointer.save(checkpoint_path, train_state)
Orbax
https://github.com/google/orbax
后續(xù)步驟: 探索高級 LLM 訓(xùn)練和擴(kuò)展
這基本上就是我們訓(xùn)練 GPT-2 模型所需了解的全部內(nèi)容。您可以在完整的Notebook中找到其他詳細(xì)信息,如數(shù)據(jù)加載、超參數(shù)、指標(biāo)等。
Notebook
https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining
當(dāng)然,GPT-2 如今還是一個小模型,許多前沿實驗室正在訓(xùn)練擁有數(shù)千億參數(shù)的模型。但是,現(xiàn)在您已經(jīng)學(xué)習(xí)了如何使用 JAX 和 TPU 構(gòu)建小語言模型,為深入了解如何擴(kuò)展模型做好了準(zhǔn)備。
如何擴(kuò)展模型
https://jax-ml.github.io/scaling-book/
此外,您既可以使用MaxText來訓(xùn)練預(yù)構(gòu)建的前沿 LLM,也可以通過參考JAX LLM 示例或Stanford Marin 模型來學(xué)習(xí)如何從頭開始構(gòu)建最新的模型。
MaxText
https://github.com/AI-Hypercomputer/maxtext
JAX LLM 示例
https://github.com/jax-ml/jax-llm-examples/
Stanford Marin 模型
https://developers.googleblog.com/en/stanfords-marin-foundation-model-first-fully-open-model-developed-using-jax/
我們期待看到您使用 JAX 和 TPU 構(gòu)建的出色模型!
-
模型
+關(guān)注
關(guān)注
1文章
3640瀏覽量
51678 -
代碼
+關(guān)注
關(guān)注
30文章
4940瀏覽量
73052 -
TPU
+關(guān)注
關(guān)注
0文章
164瀏覽量
21523 -
pytorch
+關(guān)注
關(guān)注
2文章
812瀏覽量
14657
原文標(biāo)題:實戰(zhàn)指南|手把手教您在 TPU 上免費(fèi)使用 JAX 訓(xùn)練 GPT-2 模型
文章出處:【微信號:Google_Developers,微信公眾號:谷歌開發(fā)者】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
用PaddleNLP在4060單卡上實踐大模型預(yù)訓(xùn)練技術(shù)
如何利用Google Colab的云TPU加速Keras模型訓(xùn)練
OpenAI發(fā)布了一個“逆天”的AI模型——GPT2整個模型包含15億個參數(shù)
OpenAI發(fā)布一款令人印象深刻的語言模型GPT-2
布朗大學(xué)90后研究生:我們復(fù)現(xiàn)了15億參數(shù)GPT-2模型,你也行!
OpenAI宣布,發(fā)布了7.74億參數(shù)GPT-2語言模型
和AI聊天,自然語言模型 GPT-2可能會推出個人信息
GPT系列的“高仿” 最大可達(dá)GPT-3大小 自主訓(xùn)練
使用NVIDIA TensorRT優(yōu)化T5和GPT-2
基于OpenAI的GPT-2的語言模型ProtGPT2可生成新的蛋白質(zhì)序列
GPT/GPT-2/GPT-3/InstructGPT進(jìn)化之路
ELMER: 高效強(qiáng)大的非自回歸預(yù)訓(xùn)練文本生成模型
DeepSpeed里面和Zero相關(guān)技術(shù)教程
DeepSpeed結(jié)合Megatron-LM訓(xùn)練GPT2模型筆記
用PaddleNLP為GPT-2模型制作FineWeb二進(jìn)制預(yù)訓(xùn)練數(shù)據(jù)集

如何在TPU上使用JAX訓(xùn)練GPT-2模型
評論