chinese直男口爆体育生外卖, 99久久er热在这里只有精品99, 又色又爽又黄18禁美女裸身无遮挡, gogogo高清免费观看日本电视,私密按摩师高清版在线,人妻视频毛茸茸,91论坛 兴趣闲谈,欧美 亚洲 精品 8区,国产精品久久久久精品免费

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

如何在TPU上使用JAX訓(xùn)練GPT-2模型

谷歌開發(fā)者 ? 來源:谷歌開發(fā)者 ? 2025-09-03 11:39 ? 次閱讀
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

作者 / 魏巍,開發(fā)技術(shù)推廣工程師

如果您對如何使用 JAX 從頭開始構(gòu)建語言模型感到好奇,那么本文非常適合您。我們在 2025 年 Google Cloud Next 大會上舉辦了一場關(guān)于此主題的研討會,并獲得了一些很好的反饋,我們也為所有無法參會的開發(fā)者編寫了這份指南。

本文和代碼示例將引導(dǎo)您構(gòu)建并預(yù)訓(xùn)練 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的強大能力。您可以使用 Colab 或 Kaggle 中的 TPU 免費運行整個項目,并獲取完整的Notebook。

Notebook

https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining

這是一個實踐教程,如果您還不熟悉 JAX,我們建議您從《PyTorch 開發(fā)者指南: JAX 基礎(chǔ)知識》入手。

PyTorch 開發(fā)者指南: JAX 基礎(chǔ)知識

https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

首先,讓我們快速了解一下將要用到的工具。

JAX 生態(tài)系統(tǒng)

在開始構(gòu)建模型之前,讓我們先簡要介紹一下 JAX 生態(tài)系統(tǒng)。JAX 生態(tài)系統(tǒng)采用模塊化方法,通過 JAX 核心提供核心數(shù)值處理能力,而一系列豐富的庫則在此基礎(chǔ)上構(gòu)建而成,以滿足不同應(yīng)用的特定需求,如用于構(gòu)建神經(jīng)網(wǎng)絡(luò)的Flax、用于檢查點和模型持久性的Orbax以及用于優(yōu)化的Optax(在本文中,這 3 個工具都將被用到)。內(nèi)置函數(shù)轉(zhuǎn)換,如 autograd、矢量化和 JIT 編譯,加上強大的性能和易于使用的 API,使 JAX 非常適合訓(xùn)練大語言模型。

JAX 生態(tài)系統(tǒng)

https://docs.jax.dev/en/latest/#ecosystem

Flax

https://github.com/google/flax

Orbax

https://github.com/google/orbax

Optax

https://github.com/google-deepmind/optax

入門指南: 構(gòu)建您的 GPT-2 模型

OpenAI 此前發(fā)布了GPT-2 模型代碼和權(quán)重,這為我們提供了寶貴的參考資料,并且社區(qū)也付出了很多努力來復(fù)現(xiàn)該模型,例如nanoGPT。以下是 GPT-2 的高層級模型架構(gòu)圖:

dedd83ce-84bb-11f0-a18e-92fbcf53809c.png

GPT-2 模型代碼和權(quán)重

https://github.com/openai/gpt-2

nanoGPT

https://github.com/karpathy/nanoGPT

我們將使用NNX (新的 Flax 接口)來構(gòu)建 GPT-2 模型。簡潔起見,我們重點關(guān)注 Transformer Block,這是現(xiàn)代大語言模型的關(guān)鍵所在。Transformer Block 會捕獲任何序列的長程依賴關(guān)系,并構(gòu)建豐富的上下文理解。GPT-2 Transformer Block 由 2 個 LayerNorm 層、1 個多頭注意力 (MHA) 層、2 個 Dropout 層、2 個線性投影層和 2 個殘差連接組成。因此,我們首先需要在TransformerBlock類的__init__函數(shù)中定義這些層:

classTransformerBlock(nnx.Module):
 def__init__(
    self,
    embed_dim:int,
    num_heads:int,
    ff_dim:int,
    dropout_rate:float,
    rngs: nnx.Rngs,
  ):
    self.layer_norm1 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim, rngs=rngs
    )
    self.mha = nnx.MultiHeadAttention(
      num_heads=num_heads, in_features=embed_dim, rngs=rngs
    )
    self.dropout1 = nnx.Dropout(rate=dropout_rate)
    self.layer_norm2 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim, rngs=rngs
    )
    self.linear1 = nnx.Linear(
      in_features=embed_dim, out_features=ff_dim, rngs=rngs
    )
    self.linear2 = nnx.Linear(
      in_features=ff_dim, out_features=embed_dim, rngs=rngs
    )
    self.dropout2 = nnx.Dropout(rate=dropout_rate)

NNX (新的 Flax 接口)

https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html#

接下來,我們需要在__call__函數(shù)中對這些層進行組合:

classTransformerBlock(nnx.Module):
 def__call__(self, inputs, training:bool=False):
    input_shape = inputs.shape
    bs, seq_len, emb_sz = input_shape


    attention_output = self.mha(
      inputs_q=self.layer_norm1(inputs),
      mask=causal_attention_mask(seq_len),
      decode=False,
    )
    x = inputs + self.dropout1(
      attention_output, deterministic=nottraining
    )


   # MLP
    mlp_output = self.linear1(self.layer_norm2(x))
    mlp_output = nnx.gelu(mlp_output)
    mlp_output = self.linear2(mlp_output)
    mlp_output = self.dropout2(
      mlp_output, deterministic=nottraining
    )


   returnx + mlp_output

如果您使用過任何其他機器學(xué)習(xí)框架 (如 PyTorch 或 TensorFlow) 來訓(xùn)練語言模型,那么您對這段代碼應(yīng)該非常熟悉。但 JAX 具有通過SPMD(Single Program Multiple Data) 自動并行運行代碼的強大能力。這項功能至關(guān)重要,因為我們將在多個加速器 (多個 TPU 核心) 上運行代碼。讓我們來看看它的工作原理。

SPMD

https://docs.jax.dev/en/latest/sharded-computation.html

要執(zhí)行 SPMD,首先我們需要確保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,請選擇 TPU 運行時 (您也可以使用 Cloud TPU 虛擬機)。

import jax
jax.devices()


# Free-tier Colab offers TPU v2:
#[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 個獨立的 TPU 核心。TPU v3 托盤的外觀如下所示:

def0e464-84bb-11f0-a18e-92fbcf53809c.png

訓(xùn)練您的 GPT-2 模型

為了高效訓(xùn)練 GPT-2 模型,我們將通過 SPMD 讓所有 TPU 核心協(xié)同運行,并利用 JAX 中的數(shù)據(jù)并行。為此,我們定義了一個硬件網(wǎng)格:

mesh= jax.make_mesh((8,1), ('batch','model'))

數(shù)據(jù)并行

https://en.wikipedia.org/wiki/Data_parallelism

我們可以將網(wǎng)格視為加速器的 2D 矩陣。在本例中,我們?yōu)榫W(wǎng)格定義了兩個軸:batch軸和model軸。因此,我們總共有 8 x 1 個核心,也就是 8 個核心。這些軸決定了我們?nèi)绾蝿澐謹?shù)據(jù)和模型參數(shù)。如果之后想嘗試其他并行方案,我們可以對這些軸進行調(diào)整。

現(xiàn)在,我們通過告訴 JAX 如何使用 "model" 軸劃分模型參數(shù)來更改__init__函數(shù)。這是通過在初始化權(quán)重張量 (weight tensors) 時添加nnx.with_partitioning來實現(xiàn)的: 對于像 LayerNorm 縮放/偏置張量這樣的 1D 權(quán)重張量 (weight tensors),我們直接沿著 "model" 軸對它們進行分片;對于像 MHA 和線性內(nèi)核張量這樣的 2D 權(quán)重張量,我們沿著model軸對第二維度進行分片。

classTransformerBlock(nnx.Module):
 def__init__(
    self,
    embed_dim:int,
    num_heads:int,
    ff_dim:int,
    dropout_rate:float,
    rngs: nnx.Rngs,
  ):
    self.layer_norm1 = nnx.LayerNorm(
      epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
      scale_init=nnx.with_partitioning(
        nnx.initializers.ones_init(),
        ("model"),
      ),
      bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
       ("model"),
      ),
    )
    self.mha = nnx.MultiHeadAttention(
      num_heads=num_heads, in_features=embed_dim,
      kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
       (None,"model"),
      ),
      bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
       ("model"),
      ),
    )
   # Other layers in the block are omitted for brevity

我們需要像這樣劃分其他層,以便為整個 GPT-2 模型啟用模型張量并行。即使我們在本教程中不會使用模型張量并行,實現(xiàn)這一功能仍然是比較好的做法,因為隨著模型規(guī)模的增長,我們將來可能需要對模型參數(shù)進行分區(qū)。實現(xiàn)后,我們只需更改一行代碼即可立即運行更大的模型。例如:

mesh= jax.make_mesh((4,2), ('batch','model'))

接下來,我們需要定義loss_fn和train_step函數(shù),與此前文章類似。train_step()函數(shù)會計算交叉熵損失函數(shù)的梯度,并通過優(yōu)化器更新權(quán)重,然后在循環(huán)中被調(diào)用來訓(xùn)練模型。為了獲得最佳性能,我們使用@nnx.jit裝飾器對這兩個函數(shù)進行 JIT 編譯,因為它們屬于計算密集型函數(shù)。

@nnx.jit
defloss_fn(model, batch):
  logits = model(batch[0])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch[1]
  ).mean()
 returnloss, logits




@nnx.jit
deftrain_step(
  model: nnx.Module,
  optimizer: nnx.Optimizer,
  metrics: nnx.MultiMetric,
  batch,
):
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, lables=batch[1])
  optimizer.update(grads)

此前文章

https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

對于優(yōu)化器,我們使用 Optax 中的 AdamW 以及余弦衰減調(diào)度。您也可以在 Optax 中試用其他優(yōu)化器或調(diào)度計劃。

schedule = optax.cosine_decay_schedule(
  init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
  optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)

其他優(yōu)化器

https://optax.readthedocs.io/en/latest/api/optimizers.html

調(diào)度計劃

https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html

最后,我們需要創(chuàng)建一個簡單的訓(xùn)練循環(huán)。

while True:
  input_batch, target_batch =get_batch("train")


 train_step(
    model,
    optimizer,
    train_metrics,
    jax.device_put(
      (input_batch, target_batch),
     NamedSharding(mesh,P("batch", None)),
    ),
  )


  step +=1
  if step > max_steps:
    break

請注意我們使用jax.device_put函數(shù)沿著 batch 軸對輸入數(shù)據(jù)進行分區(qū)。在這種情況下,JAX 將啟用數(shù)據(jù)并行,并通過自動插入通信集合 (AllReduce) 將所有內(nèi)容整合在一起,同時盡可能多地實現(xiàn)計算與通信的重疊。有關(guān)并行計算更深入的討論,請參閱 JAX 的并行編程入門文檔。

并行編程入門

https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#intro-and-a-quick-example

模型此時應(yīng)處于訓(xùn)練狀態(tài),如果使用權(quán)重和偏差來跟蹤運行情況,我們便可以觀察訓(xùn)練損失。以下是訓(xùn)練 GPT-2 124M 模型的測試運行結(jié)果:

df146da8-84bb-11f0-a18e-92fbcf53809c.png

權(quán)重和偏差

https://wandb.ai/site

如果使用 Kaggle TPU v3,訓(xùn)練時間大約為 7 個小時 (我們可以不中斷地使用 Kaggle TPU v3 9 個小時);但如果使用Trillium,訓(xùn)練時間將縮短至約 1.5 個小時 (請注意,Trillium 的每個芯片配備 32G 高帶寬內(nèi)存 (HBM),因此我們可以將批量大小加倍,并將訓(xùn)練步數(shù)減半)。

Trillium

https://cloud.google.com/blog/products/compute/trillium-tpu-is-ga

最終的損失情況與nanoGPT 的損失情況大致相符。我們在編寫此代碼示例時對 nanoGPT 進行了研究。

df270288-84bb-11f0-a18e-92fbcf53809c.png

nanoGPT 的損失情況

https://github.com/karpathy/nanoGPT/tree/master?tab=readme-ov-file#baselines

如果使用 Cloud TPU,我們還可以通過 "tpu-info" 命令 (Cloud TPU 監(jiān)控調(diào)試包的一部分) 或權(quán)重和偏差儀表盤監(jiān)控 TPU 利用率。我們的 TPU 正在全力運行!

df3f1f4e-84bb-11f0-a18e-92fbcf53809c.png

Cloud TPU 監(jiān)控調(diào)試

https://github.com/AI-Hypercomputer/cloud-tpu-monitoring-debugging

完成模型訓(xùn)練后,我們可以使用Orbax保存模型:

checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)

Orbax

https://github.com/google/orbax

后續(xù)步驟: 探索高級 LLM 訓(xùn)練和擴展

這基本上就是我們訓(xùn)練 GPT-2 模型所需了解的全部內(nèi)容。您可以在完整的Notebook中找到其他詳細信息,如數(shù)據(jù)加載、超參數(shù)、指標等。

Notebook

https://github.com/windmaple/LLM_from_scratch.JAX/tree/main/02.GPT2-pretraining

當(dāng)然,GPT-2 如今還是一個小模型,許多前沿實驗室正在訓(xùn)練擁有數(shù)千億參數(shù)的模型。但是,現(xiàn)在您已經(jīng)學(xué)習(xí)了如何使用 JAX 和 TPU 構(gòu)建小語言模型,為深入了解如何擴展模型做好了準備。

如何擴展模型

https://jax-ml.github.io/scaling-book/

此外,您既可以使用MaxText來訓(xùn)練預(yù)構(gòu)建的前沿 LLM,也可以通過參考JAX LLM 示例或Stanford Marin 模型來學(xué)習(xí)如何從頭開始構(gòu)建最新的模型。

MaxText

https://github.com/AI-Hypercomputer/maxtext

JAX LLM 示例

https://github.com/jax-ml/jax-llm-examples/

Stanford Marin 模型

https://developers.googleblog.com/en/stanfords-marin-foundation-model-first-fully-open-model-developed-using-jax/

我們期待看到您使用 JAX 和 TPU 構(gòu)建的出色模型!

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3607

    瀏覽量

    51408
  • 代碼
    +關(guān)注

    關(guān)注

    30

    文章

    4921

    瀏覽量

    72199
  • TPU
    TPU
    +關(guān)注

    關(guān)注

    0

    文章

    160

    瀏覽量

    21430
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    812

    瀏覽量

    14413

原文標題:實戰(zhàn)指南|手把手教您在 TPU 上免費使用 JAX 訓(xùn)練 GPT-2 模型

文章出處:【微信號:Google_Developers,微信公眾號:谷歌開發(fā)者】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

    評論

    相關(guān)推薦
    熱點推薦

    用PaddleNLP在4060單卡實踐大模型預(yù)訓(xùn)練技術(shù)

    手把手教您如何在單張消費級顯卡,利用PaddleNLP實踐OpenAI的GPT-2模型的預(yù)訓(xùn)練。GPT
    的頭像 發(fā)表于 02-19 16:10 ?1733次閱讀
    用PaddleNLP在4060單卡<b class='flag-5'>上</b>實踐大<b class='flag-5'>模型</b>預(yù)<b class='flag-5'>訓(xùn)練</b>技術(shù)

    如何利用Google Colab的云TPU加速Keras模型訓(xùn)練

    TPU包含8個TPU核,每個核都作為獨立的處理單元運作。如果沒有用上全部8個核心,那就沒有充分利用TPU。為了充分加速訓(xùn)練,相比在單GPU
    的頭像 發(fā)表于 11-16 09:10 ?1.1w次閱讀

    OpenAI發(fā)布了一個“逆天”的AI模型——GPT2整個模型包含15億個參數(shù)

    能有這樣出色的表現(xiàn),不是沒有原因的,GPT-2各種特定領(lǐng)域的語言建模任務(wù)中都取得了很好的分數(shù)。作為一個沒有經(jīng)過任何領(lǐng)域數(shù)據(jù)專門訓(xùn)練模型,它的表現(xiàn),比那些專為特定領(lǐng)域數(shù)據(jù)集(例如維基百科,新聞,書籍)
    的頭像 發(fā)表于 03-07 14:45 ?9146次閱讀

    OpenAI發(fā)布一款令人印象深刻的語言模型GPT-2

    今年2月,OpenAI發(fā)布了一款令人印象深刻的語言模型GPT-2,它可以寫短篇小說、詩歌,甚至輕松辨別《哈利波特》和《指環(huán)王》中的角色。最近,一位加拿大工程師用它創(chuàng)建了一個向公眾開放的文本生成器,只需提供一個句子,機器便能自動生
    的頭像 發(fā)表于 05-17 18:48 ?4940次閱讀

    布朗大學(xué)90后研究生:我們復(fù)現(xiàn)了15億參數(shù)GPT-2模型,你也行!

    模型的實現(xiàn)基于Grover模型,并修改其代碼庫以匹配GPT-2的語言建模訓(xùn)練目標。由于他們的模型是在類似的大型語料庫上進行
    的頭像 發(fā)表于 09-01 07:11 ?3711次閱讀

    OpenAI宣布,發(fā)布了7.74億參數(shù)GPT-2語言模型

    就在本周,OpenAI宣布,發(fā)布了7.74億參數(shù)GPT-2語言模型,15.58億的完整模型也有望于幾個月內(nèi)發(fā)布,并將GPT-2這6個月的進展情況在博客
    的頭像 發(fā)表于 09-01 09:10 ?3400次閱讀

    和AI聊天,自然語言模型 GPT-2可能會推出個人信息

    Stroudsburg……” 自然語言模型 GPT-2就像是收到了某種暗號,立刻“送出”一套 個人信息:姓名、電話號碼,還有地址、郵箱和傳真 (部分信息已打碼)。 這可不是GPT-2瞎編的,而是真實存在的個人信息!這些個人信息
    的頭像 發(fā)表于 01-02 09:22 ?2872次閱讀

    GPT系列的“高仿” 最大可達GPT-3大小 自主訓(xùn)練

    雖然GPT-3沒有開源,卻已經(jīng)有人在復(fù)刻GPT系列的模型了。 例如,慕尼黑工業(yè)大學(xué)的Connor Leahy,此前用200個小時、6000RMB,復(fù)現(xiàn)了GPT-2。 又例如,基于150
    的頭像 發(fā)表于 02-13 09:24 ?3187次閱讀

    使用NVIDIA TensorRT優(yōu)化T5和GPT-2

    在這篇文章中,我們向您介紹了如何將擁抱臉 PyTorch T5 和 GPT-2 模型轉(zhuǎn)換為優(yōu)化的 TensorRT 推理引擎。 TensorRT 推理機用作原始 HuggingFace T5
    的頭像 發(fā)表于 03-31 17:25 ?4429次閱讀
    使用NVIDIA TensorRT優(yōu)化T5和<b class='flag-5'>GPT-2</b>

    基于OpenAI的GPT-2的語言模型ProtGPT2可生成新的蛋白質(zhì)序列

    人類語言與蛋白質(zhì)有很多共同點,至少在計算建模方面。這使得研究團隊將自然語言處理(NLP)的新方法應(yīng)用于蛋白質(zhì)設(shè)計。其中,德國Bayreuth大學(xué)Birte H?cker的蛋白質(zhì)設(shè)計實驗室,描述了基于OpenAI的GPT-2的語言模型ProtGPT
    的頭像 發(fā)表于 09-08 16:24 ?3033次閱讀

    GPT/GPT-2/GPT-3/InstructGPT進化之路

    在預(yù)訓(xùn)練階段,GPT 選擇 transformer 的 decoder 部分作為模型的主要模塊,transformer 是 2017年 google 提出的一種特征抽取模型,
    的頭像 發(fā)表于 03-03 11:14 ?4730次閱讀

    ELMER: 高效強大的非自回歸預(yù)訓(xùn)練文本生成模型

    每個單詞都依賴于輸入文本與之前生成的單詞。自回歸生成模型只建模了前向的單詞依賴關(guān)系,依次生成的結(jié)構(gòu)也使得自回歸模型難以并行化。目前大部分預(yù)訓(xùn)練生成模型均采用自回歸方式,包括
    的頭像 發(fā)表于 03-13 10:39 ?1996次閱讀

    DeepSpeed里面和Zero相關(guān)技術(shù)教程

    和NVMe 分配大規(guī)模Megatron-LM模型 以內(nèi)存為中心的分塊優(yōu)化 提取權(quán)重 ZeRO-Offload概述 訓(xùn)練環(huán)境 在單個 V100 GPU 訓(xùn)練10B的
    的頭像 發(fā)表于 06-12 10:25 ?5423次閱讀
    DeepSpeed里面和Zero相關(guān)技術(shù)教程

    DeepSpeed結(jié)合Megatron-LM訓(xùn)練GPT2模型筆記

    本文基于DeepSpeedExamples倉庫中給出的Megatron相關(guān)例子探索一下訓(xùn)練GPT2模型的流程。主要包含3個部分,第一個部分是基于原始的Megatron如何訓(xùn)練
    的頭像 發(fā)表于 06-19 14:45 ?4541次閱讀
    DeepSpeed結(jié)合Megatron-LM<b class='flag-5'>訓(xùn)練</b><b class='flag-5'>GPT2</b><b class='flag-5'>模型</b>筆記

    用PaddleNLP為GPT-2模型制作FineWeb二進制預(yù)訓(xùn)練數(shù)據(jù)集

    作者:算力魔方創(chuàng)始人/英特爾創(chuàng)新大使劉力 《用PaddleNLP在4060單卡實踐大模型預(yù)訓(xùn)練技術(shù)》發(fā)布后收到讀者熱烈反響,很多讀者要求進一步講解更多的技術(shù)細節(jié)。本文主要針對大語言模型
    的頭像 發(fā)表于 03-21 18:24 ?3355次閱讀
    用PaddleNLP為<b class='flag-5'>GPT-2</b><b class='flag-5'>模型</b>制作FineWeb二進制預(yù)<b class='flag-5'>訓(xùn)練</b>數(shù)據(jù)集