本文轉(zhuǎn)自:Coggle數(shù)據(jù)科學(xué)
在 Transformer 架構(gòu)中,注意力機(jī)制的計(jì)算復(fù)雜度與序列長(zhǎng)度(即文本長(zhǎng)度)呈平方關(guān)系()。這意味著,當(dāng)模型需要處理更長(zhǎng)的文本時(shí)(比如從幾千個(gè)詞到幾萬(wàn)個(gè)詞),計(jì)算時(shí)間和所需的內(nèi)存會(huì)急劇增加。最開(kāi)始的標(biāo)準(zhǔn)注意力機(jī)制存在兩個(gè)主要問(wèn)題:
- 內(nèi)存占用高:模型需要生成一個(gè)巨大的注意力矩陣 (N×N)。這個(gè)矩陣需要被保存在高帶寬內(nèi)存 (HBM)中。對(duì)于長(zhǎng)序列,這很快就會(huì)超出 GPU 的內(nèi)存容量。
- 計(jì)算效率低:標(biāo)準(zhǔn)實(shí)現(xiàn)會(huì)將注意力計(jì)算分解成多個(gè)獨(dú)立的步驟(矩陣乘法、softmax 等)。每一步都需要將數(shù)據(jù)從速度較慢的 HBM 中讀取,計(jì)算后又寫(xiě)回 HBM。這種頻繁的數(shù)據(jù)移動(dòng)(內(nèi)存讀寫(xiě))成為了性能瓶頸,導(dǎo)致 GPU 的計(jì)算單元(如 Tensor Cores)利用率低下。
什么是 FlashAttention?
FlashAttention 使得處理長(zhǎng)達(dá)數(shù)萬(wàn)甚至數(shù)十萬(wàn)個(gè) token 的超長(zhǎng)文本成為可能。這解鎖了新的應(yīng)用場(chǎng)景,例如分析法律文檔、總結(jié)長(zhǎng)篇小說(shuō)或處理整個(gè)代碼庫(kù)。
FlashAttention 使得模型的訓(xùn)練和推理速度更快,尤其是在長(zhǎng)序列場(chǎng)景下。例如,F(xiàn)lashAttention-2 在長(zhǎng)序列上比標(biāo)準(zhǔn)實(shí)現(xiàn)快 10 倍,使得訓(xùn)練成本更低,用戶(hù)體驗(yàn)更好。
最新的 FlashAttention-3 利用了新硬件(如 NVIDIA H100)的 FP8 精度,進(jìn)一步提升了性能,同時(shí)通過(guò)特殊的算法保持了計(jì)算的準(zhǔn)確性,讓模型訓(xùn)練更加高效。
FlashAttention v1
許多研究提出了近似注意力方法,試圖通過(guò)減少計(jì)算量(FLOPs)來(lái)提高效率。然而,這些方法通常忽略了GPU不同層級(jí)內(nèi)存(如高速的片上SRAM和相對(duì)較慢的高帶寬HBM)之間的I/O開(kāi)銷(xiāo),導(dǎo)致它們?cè)趯?shí)際運(yùn)行時(shí)并沒(méi)有帶來(lái)顯著的加速。
FlashAttention的核心思想是I/O感知,即在設(shè)計(jì)算法時(shí),將數(shù)據(jù)在不同層級(jí)內(nèi)存之間的讀寫(xiě)開(kāi)銷(xiāo)考慮在內(nèi)。論文指出,在現(xiàn)代GPU上,計(jì)算速度已經(jīng)遠(yuǎn)超內(nèi)存訪(fǎng)問(wèn)速度,因此大多數(shù)操作都受限于內(nèi)存訪(fǎng)問(wèn)。FlashAttention通過(guò)以下兩個(gè)關(guān)鍵技術(shù)來(lái)解決這一問(wèn)題:
- Tiling (平鋪):將輸入數(shù)據(jù)(Q、K、V矩陣)分割成小塊,并在GPU的片上SRAM中進(jìn)行計(jì)算。這樣可以避免將龐大的 N×N 注意力矩陣完整地寫(xiě)入到速度較慢的HBM中。
- 內(nèi)存優(yōu)化:在反向傳播時(shí),F(xiàn)lashAttention 不存儲(chǔ)巨大的中間注意力矩陣,而是只保存前向傳播中計(jì)算出的Softmax歸一化因子。這樣,反向傳播時(shí)可以利用這些因子在SRAM中快速地重新計(jì)算注意力矩陣,從而避免了從HBM讀取大矩陣的開(kāi)銷(xiāo)。
GPU內(nèi)存層級(jí)
- HBM (高帶寬內(nèi)存):容量大(如A100 GPU的40-80 GB),但速度相對(duì)較慢(帶寬1.5-2.0 TB/s)。
- 片上SRAM (靜態(tài)隨機(jī)存取存儲(chǔ)器):容量?。總€(gè)流式多處理器有192 KB),但速度極快(帶寬估計(jì)達(dá)19 TB/s),比HBM快一個(gè)數(shù)量級(jí)以上。
由于GPU的計(jì)算速度增長(zhǎng)快于內(nèi)存速度,許多操作的性能瓶頸在于內(nèi)存訪(fǎng)問(wèn),而不是計(jì)算本身。因此,如何高效利用快速的SRAM變得至關(guān)重要。
運(yùn)算類(lèi)型
根據(jù)算術(shù)強(qiáng)度(每字節(jié)內(nèi)存訪(fǎng)問(wèn)的算術(shù)運(yùn)算次數(shù)),操作可分為兩類(lèi):
- 計(jì)算密集型 (Compute-bound):運(yùn)算時(shí)間由算術(shù)操作數(shù)量決定,內(nèi)存訪(fǎng)問(wèn)時(shí)間相對(duì)較小。例如,大規(guī)模矩陣乘法。
- 內(nèi)存密集型 (Memory-bound):運(yùn)算時(shí)間由內(nèi)存訪(fǎng)問(wèn)次數(shù)決定,計(jì)算時(shí)間相對(duì)較小。例如,大多數(shù)元素級(jí)操作(如激活函數(shù)、Dropout)和歸約操作(如Softmax、LayerNorm)。
注意力實(shí)現(xiàn)改進(jìn)
給定查詢(xún) Q、鍵 K 和值 V 矩陣,注意力的計(jì)算分三步:
- 相似度計(jì)算:
- Softmax歸一化:
- 加權(quán)求和:
標(biāo)準(zhǔn)實(shí)現(xiàn)(如“Algorithm 0”所示)將每一步都作為一個(gè)獨(dú)立的GPU核函數(shù),并物化(materialize)中間矩陣 S 和 P 到HBM中。
這種實(shí)現(xiàn)方式導(dǎo)致了兩個(gè)主要問(wèn)題:
- 巨大的內(nèi)存占用:中間矩陣 S 和 P 的大小為 N×N,其內(nèi)存占用與序列長(zhǎng)度 N 的平方成正比。
- 大量的HBM訪(fǎng)問(wèn):由于每個(gè)步驟都需要讀寫(xiě)HBM,導(dǎo)致I/O開(kāi)銷(xiāo)巨大。論文指出,這種方法對(duì)HBM的訪(fǎng)問(wèn)次數(shù)是 O(N2) 級(jí)別的,這在長(zhǎng)序列(通常 N?d)時(shí)會(huì)成為主要的性能瓶頸,導(dǎo)致運(yùn)行時(shí)間慢。
FlashAttention旨在減少對(duì)GPU高帶寬內(nèi)存(HBM)的讀寫(xiě),實(shí)現(xiàn)對(duì)確切注意力(exact attention)的快速、內(nèi)存高效的計(jì)算。為此,它采用了兩種關(guān)鍵技術(shù):
- Tiling(分塊):將輸入的 Q,K,V 矩陣分成若干小塊。然后,在計(jì)算過(guò)程中,每次只將一小塊數(shù)據(jù)從慢速的HBM加載到快速的片上SRAM進(jìn)行計(jì)算,而不是一次性加載整個(gè)大矩陣。
- Recomputation(重計(jì)算):為了避免在反向傳播時(shí)存儲(chǔ) O(N2) 的中間注意力矩陣 S 和 P,F(xiàn)lashAttention只存儲(chǔ) Softmax 的歸一化統(tǒng)計(jì)量(即 m 和 ?)。在反向傳播時(shí),它會(huì)利用這些統(tǒng)計(jì)量,按需在SRAM中重新計(jì)算必要的注意力矩陣塊。
通過(guò)Tiling和Recomputation,F(xiàn)lashAttention能夠?qū)⑺杏?jì)算步驟(矩陣乘法、Softmax、可選的遮蔽和Dropout)融合成一個(gè)單一的CUDA核函數(shù)。這避免了在每個(gè)步驟之間反復(fù)地將數(shù)據(jù)寫(xiě)入HBM。
實(shí)現(xiàn)效果
lashAttention在BERT-large模型上的訓(xùn)練速度超過(guò)了MLPerf 1.1的記錄保持者。與Nvidia的實(shí)現(xiàn)相比,F(xiàn)lashAttention的訓(xùn)練時(shí)間縮短了15%,這證明了其在標(biāo)準(zhǔn)長(zhǎng)序列任務(wù)上的卓越性能。
FlashAttention在訓(xùn)練GPT-2模型時(shí),相比于流行的HuggingFace和Megatron-LM實(shí)現(xiàn),實(shí)現(xiàn)了顯著的端到端加速。
- 與Huggingface相比,速度提升高達(dá)3倍。
- 與Megatron-LM相比,速度提升高達(dá)1.7倍。
- 重要的是,F(xiàn)lashAttention在不改變模型定義的情況下,實(shí)現(xiàn)了與基線(xiàn)模型相同的困惑度(perplexity),證明了其數(shù)值穩(wěn)定性。
在Long-Range Arena基準(zhǔn)測(cè)試中,F(xiàn)lashAttention相比于標(biāo)準(zhǔn)的Transformer實(shí)現(xiàn),實(shí)現(xiàn)了2.4倍的加速。此外,塊稀疏FlashAttention的表現(xiàn)甚至優(yōu)于所有已測(cè)試的近似注意力方法,證明了其在處理超長(zhǎng)序列時(shí)的優(yōu)越性。
lashAttention的內(nèi)存占用與序列長(zhǎng)度呈線(xiàn)性關(guān)系,而標(biāo)準(zhǔn)實(shí)現(xiàn)是平方關(guān)系。這使得FlashAttention的內(nèi)存效率比標(biāo)準(zhǔn)方法高出20倍。
FlashAttention v2
第一代FlashAttention通過(guò)利用 GPU 內(nèi)存層次結(jié)構(gòu)的特性,顯著降低了內(nèi)存占用(從二次方降為線(xiàn)性)并實(shí)現(xiàn)了 2-4 倍的加速,且沒(méi)有引入任何近似。
然而,F(xiàn)lashAttention 的效率仍然不如優(yōu)化的矩陣乘法(GEMM)操作,其浮點(diǎn)運(yùn)算性能(FLOPs/s)僅能達(dá)到理論峰值的 25-40%。這主要是因?yàn)?FlashAttention 存在不優(yōu)化的工作劃分(work partitioning),導(dǎo)致 GPU 線(xiàn)程塊(thread blocks)和線(xiàn)程束(warps)之間的并行度不足、占用率低或產(chǎn)生不必要的共享內(nèi)存讀寫(xiě)。
為了解決這些問(wèn)題,論文提出了FlashAttention-2,通過(guò)以下改進(jìn)實(shí)現(xiàn)了更好的工作劃分:
- 減少非矩陣乘法(non-matmul)的浮點(diǎn)運(yùn)算:雖然這類(lèi)操作占總 FLOPs 的比例小,但執(zhí)行起來(lái)很慢。
- 在序列長(zhǎng)度維度上并行化:即使對(duì)于單個(gè)注意力頭,也將其計(jì)算任務(wù)分配給不同的線(xiàn)程塊,以提高 GPU 的占用率。
- 優(yōu)化線(xiàn)程塊內(nèi)部的工作分配:在每個(gè)線(xiàn)程塊內(nèi),重新分配線(xiàn)程束之間的工作,以減少通過(guò)共享內(nèi)存進(jìn)行的通信。
前向傳播改進(jìn)
FlashAttention-2對(duì)在線(xiàn) Softmax 技巧進(jìn)行了兩處微調(diào):
- 延遲歸一化:在每個(gè)循環(huán)迭代中,不立即對(duì)輸出進(jìn)行歸一化。相反,它維護(hù)一個(gè)“未縮放”的中間結(jié)果,并在整個(gè)循環(huán)結(jié)束時(shí)僅進(jìn)行一次最終的歸一化。這減少了每個(gè)塊的縮放操作,從而減少了非 matmul 的 FLOPs。
- 簡(jiǎn)化統(tǒng)計(jì)量:為反向傳播存儲(chǔ)數(shù)據(jù)時(shí),只保存logsumexp統(tǒng)計(jì)量 L(j)=m(j)+log(?(j)),而不是同時(shí)存儲(chǔ)最大值 m(j) 和指數(shù)和 ?(j)。
并行化改進(jìn)
第一代 FlashAttention 僅在批處理大小和注意力頭數(shù)量上進(jìn)行并行化。當(dāng)序列長(zhǎng)度很長(zhǎng)時(shí),批處理大小通常很小,導(dǎo)致 GPU 資源的利用率(occupancy)不高。FlashAttention-2 通過(guò)在序列長(zhǎng)度維度上增加并行化來(lái)解決這個(gè)問(wèn)題。
- 前向傳播:FlashAttention-2 將注意力矩陣的行塊任務(wù)分配給不同的線(xiàn)程塊,這些線(xiàn)程塊之間無(wú)需通信。通過(guò)在行維度上并行,當(dāng)批次大小和注意力頭數(shù)較小時(shí),GPU 的 SM(流式多處理器)能夠被更充分地利用,從而提高整體吞吐量。
- 后向傳播:類(lèi)似地,后向傳播則在注意力矩陣的列塊上進(jìn)行并行。由于反向傳播中的某些更新需要跨線(xiàn)程塊通信,作者使用了原子加法(atomic adds)來(lái)更新共享的梯度 dK 和 dV,確保了線(xiàn)程安全。
除了線(xiàn)程塊級(jí)別的并行,F(xiàn)lashAttention-2 還優(yōu)化了線(xiàn)程塊內(nèi)部線(xiàn)程束之間的工作分配,以減少共享內(nèi)存的讀寫(xiě)。
- 前向傳播:
- FlashAttention:采用“split-K”方案,將 K 和 V 矩陣的計(jì)算任務(wù)分配給不同的線(xiàn)程束。這要求所有線(xiàn)程束將中間結(jié)果寫(xiě)入共享內(nèi)存,再進(jìn)行同步和求和,導(dǎo)致不必要的共享內(nèi)存訪(fǎng)問(wèn)。
- FlashAttention-2:改為將 Q 矩陣的計(jì)算任務(wù)分配給不同的線(xiàn)程束。每個(gè)線(xiàn)程束負(fù)責(zé)計(jì)算 Q 的一個(gè)分片與完整的 K 的乘積。這樣,每個(gè)線(xiàn)程束可以獨(dú)立地完成其部分輸出,而無(wú)需與其他線(xiàn)程束進(jìn)行共享內(nèi)存通信,從而顯著提高了效率。
- 后向傳播:后向傳播的依賴(lài)關(guān)系更復(fù)雜,但 FlashAttention-2 仍然通過(guò)避免“split-K”方案來(lái)減少共享內(nèi)存的讀寫(xiě),實(shí)現(xiàn)了性能提升。
實(shí)現(xiàn)效果
FlashAttention-2 比第一代 FlashAttention 快1.7-3.0 倍,比 Triton 實(shí)現(xiàn)的 FlashAttention 快1.3-2.5 倍。
在 A100 GPU 上,F(xiàn)lashAttention-2 在前向傳播中達(dá)到了230 TFLOPs/s的峰值,相當(dāng)于理論最大吞吐量的73%。在后向傳播中,它達(dá)到了理論最大吞吐量的 63%。
FlashAttention v3
雖然之前的 FlashAttention 通過(guò)減少內(nèi)存讀寫(xiě)來(lái)加速計(jì)算,但它未能充分利用現(xiàn)代硬件(如 Hopper GPU)的新特性。例如,F(xiàn)lashAttention-2 在 H100 GPU 上的利用率僅為 35%。
與 FlashAttention-2 類(lèi)似,F(xiàn)lashAttention-3 也將任務(wù)并行化到不同的線(xiàn)程塊(CTA),但其創(chuàng)新之處在于在單個(gè)線(xiàn)程塊內(nèi)部,將線(xiàn)程束(warps)劃分為不同的角色。
- 生產(chǎn)者(Producer):負(fù)責(zé)將數(shù)據(jù)從 HBM(全局內(nèi)存)異步加載到 SMEM(共享內(nèi)存)。
- 消費(fèi)者(Consumer):在數(shù)據(jù)加載完成后,從 SMEM 讀取數(shù)據(jù)并執(zhí)行計(jì)算。
生產(chǎn)者和消費(fèi)者通過(guò)一個(gè)循環(huán)緩沖區(qū)(circular buffer)進(jìn)行同步。生產(chǎn)者將數(shù)據(jù)放入緩沖區(qū),消費(fèi)者從中取出。當(dāng)緩沖區(qū)中的一個(gè)“階段”被消費(fèi)后,生產(chǎn)者就可以繼續(xù)向其中加載新數(shù)據(jù)。
線(xiàn)程內(nèi)部的 GEMM 和 Softmax 重疊
在標(biāo)準(zhǔn) FlashAttention 中,GEMM 和 Softmax 存在順序依賴(lài):Softmax 必須在第一個(gè) GEMM 計(jì)算完成后才能開(kāi)始,而第二個(gè) GEMM 必須等待 Softmax 的結(jié)果。
FlashAttention-3 通過(guò)在寄存器中使用額外的緩沖區(qū),打破了這種依賴(lài)關(guān)系。在每次循環(huán)中,它異步啟動(dòng)下一個(gè) GEMM 的計(jì)算,而同時(shí)執(zhí)行當(dāng)前 GEMM 結(jié)果的 Softmax 和更新操作。這樣,GEMM 和 Softmax 的執(zhí)行就可以重疊,提高了效率。
FP8 低精度計(jì)算
FP8 的 WGMMA(Warp Group Matrix-Multiply-Accumulate)指令要求輸入矩陣具有特定的k-major 布局,而輸入張量通常是mn-major 布局。
FlashAttention-3 選擇在 GPU 內(nèi)核中(in-kernel)進(jìn)行轉(zhuǎn)置。它利用 LDSM/STSM 指令,這些指令能夠高效地在 SMEM 和 RMEM(寄存器)之間進(jìn)行數(shù)據(jù)傳輸,并在傳輸過(guò)程中完成布局轉(zhuǎn)置,避免了代價(jià)高昂的 HBM 讀寫(xiě)。
同于傳統(tǒng)的逐張量(per-tensor)量化,F(xiàn)lashAttention-3 對(duì)每個(gè)塊進(jìn)行單獨(dú)量化。這使得每個(gè)塊可以有自己的縮放因子,從而更有效地處理離群值,減少量化誤差。
實(shí)現(xiàn)效果
FlashAttention-3 的前向傳播速度比 FlashAttention-2 快1.5-2.0 倍,后向傳播快1.5-1.75 倍。FP16 版本的 FlashAttention-3 達(dá)到了740 TFLOPs/s的峰值,相當(dāng)于 H100 GPU 理論最大吞吐量的 **75%**。
在處理中長(zhǎng)序列(1k 及以上)時(shí),F(xiàn)lashAttention-3 的性能甚至超過(guò)了 NVIDIA 自家閉源、針對(duì) H100 優(yōu)化的cuDNN庫(kù)。
-
內(nèi)存
+關(guān)注
關(guān)注
8文章
3156瀏覽量
75882 -
人工智能
+關(guān)注
關(guān)注
1811文章
49498瀏覽量
258209 -
大模型
+關(guān)注
關(guān)注
2文章
3348瀏覽量
4718
發(fā)布評(píng)論請(qǐng)先 登錄
AD的3D模型繪制功能介紹
MRAS模型和可調(diào)模型參考
壓縮模型會(huì)加速推理嗎?
3D模型基礎(chǔ)

小白開(kāi)始學(xué)RTOS 1

如何改進(jìn)和加速擴(kuò)散模型采樣的方法1

自動(dòng)駕駛車(chē)輛控制(車(chē)輛運(yùn)動(dòng)學(xué)模型)

加速度傳感器的基本力學(xué)模型是什么
寫(xiě)給小白的大模型入門(mén)科普

小白學(xué)大模型:訓(xùn)練大語(yǔ)言模型的深度指南

小白學(xué)大模型:從零實(shí)現(xiàn) LLM語(yǔ)言模型

小白學(xué)大模型:國(guó)外主流大模型匯總

評(píng)論