作者:LLM-Finder,某廠研究大語言模型和多模態(tài)學(xué)習(xí)
寫這篇文章的動機(jī)
1. 在筆者看來RLHF是LLMs智能的關(guān)鍵之一;
2. 國內(nèi)廠商在這方面投入比較少,目前看起來并沒有很重視;
3. 大家偏向于認(rèn)為ChatGPT的RLHF做法最多的線索來源于InstructGPT,但是InstructGPT原文的描述也挺含糊的,很多東西只能靠猜和結(jié)合開源的實(shí)現(xiàn)來解讀;
4. 通常學(xué)習(xí)強(qiáng)化學(xué)習(xí)所依賴鏈路比較長,筆者希望以最直觀的方式幫助大家通關(guān)。
筆者會分兩篇文章來介紹,第一篇是理論篇,第二篇是實(shí)踐篇。讀者會在第一篇學(xué)習(xí)到PPO的原理和instrcutGPT中的RLHF做法;在第二篇中學(xué)習(xí)到目前影響比較大的開源RLHF實(shí)現(xiàn)。
據(jù)公開可獲得的信息來看,ChatGPT需要有大致三個階段的訓(xùn)練過程,如上圖所示:
1.Pretraining: 在大規(guī)模“無監(jiān)督”的語料上訓(xùn)練,訓(xùn)練任務(wù)是預(yù)測下一個詞。
2.Supervised Fine-Tuning(SFT):在人類標(biāo)注上進(jìn)行微調(diào),所謂人類標(biāo)注就是人類寫Prompt,人類寫答案。然后語言模型學(xué)習(xí)模仿人類是如何作答的。這部分通常要求數(shù)據(jù)集多樣性很好,也因?yàn)闃?biāo)注成本很高,通常量級很小。
3.Reinforcement Learning with human feedback(RLHF):對于同一個Prompt把模型的多個輸出給人類排序,獲取人類偏好標(biāo)注。用人類的偏好標(biāo)注,訓(xùn)練一個reward model。訓(xùn)練得到的reward model會作為PPO算法中的reawrd function,來繼續(xù)優(yōu)化SFT得到的模型。
通常來說,第一步最有資源門檻,第三步最有技術(shù)門檻(同時也需要大量的資源),第二步最簡單。所以目前很多廠商是直接拿了開源的第一步的模型,做SFT,或者continue-pretrain(比較小規(guī)模的無監(jiān)督訓(xùn)練)再SFT。他們PR的時候可能會嘴一句,無需復(fù)雜的RLHF,只需做細(xì)致的微調(diào)也能達(dá)到很好的效果。
后面兩個步驟,通常被視作是人類偏好對齊(alignment),讓模型更好地跟隨人類的指令作回復(fù)。而一些研究發(fā)現(xiàn),對齊后的模型是會有對齊稅的現(xiàn)象的(alignment tax),即在通用能力上會有所下降。
因此,不少人是這樣認(rèn)為的:第一步預(yù)訓(xùn)練得到的模型就已經(jīng)決定了后續(xù)模型的能力上限;后面兩步要做的事情僅僅是在盡可能減少對齊稅的情況下,對齊人類偏好。
這里可以分兩種情況分析:
? SFT過數(shù)據(jù)太多遍了,導(dǎo)致大模型出現(xiàn)遺忘;
?安全性對齊很多模型能回答的問題,強(qiáng)制不讓回答肯定會對模型能力有所牽制。
在筆者看來,某種意義下RL提供了對LLM的response的Global-level的監(jiān)督,在一些需要答案非常精確的場景上,RL可能可以發(fā)揮出更大的威力。這個看法的依據(jù)也很樸素:
1. 比如在coding、數(shù)學(xué)推導(dǎo)等場景,只要response在關(guān)鍵的地方犯了一點(diǎn)點(diǎn)錯給人的感覺就是模型不會,但是SFT的loss可能區(qū)分不出來是犯錯了還是只是寫法風(fēng)格的差異。
2. SFT給定了標(biāo)準(zhǔn)答案,LLM的上限可能會被標(biāo)注者的水平所限制;RLHF則只給定了人類偏好,得到了一定(有可能是很大)程度的解放,模型有可能探索出更高程度的智能。這一點(diǎn)并不是無中生有的想法,在游戲AI領(lǐng)域有太多的驗(yàn)證,即在模仿人類玩法(imitation learning)之后,再用RL訓(xùn)練出來的模型,就是能獲得更高的智能。這里語言模型跟游戲又有多少本質(zhì)的區(qū)別呢。
InstructGPT中的RLHF
這里簡要帶過具體數(shù)據(jù)構(gòu)造和訓(xùn)練細(xì)節(jié),后面會專門有一篇對InstructGPT像素級的解讀。
如前文所述,InstructGPT也是包含3階段的訓(xùn)練,同時我們應(yīng)該注意到,RLHF這一步訓(xùn)練,實(shí)則包含兩步訓(xùn)練:
1. 訓(xùn)練Reward Model(RM);
2. 用Reward Model和SFT Model構(gòu)造Reward Function,基于PPO算法來訓(xùn)練LLM。
數(shù)據(jù)集
SFT、RM和PPO用到的數(shù)據(jù)集數(shù)據(jù)量如下表所示:
注意,上表統(tǒng)計(jì)的是prompts數(shù)量,在RM數(shù)據(jù)中每個prompt,對應(yīng)會有4~9個responses。
在構(gòu)造RM數(shù)據(jù)的時候,作者采集了用戶的prompts,每個prompts包含4~9個模型的輸出,模型的輸出會給標(biāo)注員進(jìn)行排序。
訓(xùn)練Reward Model(RM)
目標(biāo):給pormpt-response pair打分,擬合人類的偏好。
模型:這InstructGPT的paper中,雖然用了1.3B、6B和175B的GPT-3來做實(shí)驗(yàn),但是綜合考慮下,只用6B的模型來訓(xùn)練Reward Model,因?yàn)樽髡甙l(fā)現(xiàn)用175B的模型會不穩(wěn)定。把最后的unembedding層換成一個輸出為scalar的線性層。這里讀者可能會有點(diǎn)混亂,眾所周知,GPT的模型結(jié)構(gòu)是sequence-in,sequence-out的,怎么變成scalar呢?這里文章似乎也沒提到,根據(jù)筆者的判斷和開源實(shí)現(xiàn),推測是直接用最后一個token的輸出接一個linear。
Reward Model的初始化:6B的GPT-3模型在多個公開數(shù)據(jù)((ARC, BoolQ, CoQA, DROP, MultiNLI, OpenBookQA, QuAC, RACE, and Winogrande)上fintune。不過Paper中提到其實(shí)從預(yù)訓(xùn)練模型或者SFT模型開始訓(xùn)練結(jié)果也差不多。
訓(xùn)練:以前的做法是,RM每次比較兩個模型輸出的好壞,做法很簡單類似對比學(xué)習(xí),兩個樣本對應(yīng)兩個類別,RM對這兩個樣本分別輸出兩個得分,拼成一個logits向量;人類標(biāo)注比較好的那個輸出作為label,比如第一個比較好那么label為0,第二個比較好label為1;用cross entropy約束之。
但是作者發(fā)現(xiàn)這么做很容易過擬合;也不高效,因?yàn)槊勘容^一次都要重新過一下reward model。
因此作者的做法是,在一個batch里面,把每個Prompt對應(yīng)的所有的模型輸出,都過一遍Reward model,并把所有兩兩組合都比較一遍。比如一個Prompt有K個模型輸出,那么模型則只需要處理K個樣本就可以一氣兒做(2K)次比較。loss的設(shè)計(jì)如下:
很直觀,其中,x是prompt,yw和yl分別是較好和較差的模型response,rθ(x,y)是Reward Model的輸出。σ在文中似乎沒有解釋,不過根據(jù)公式推斷和開源實(shí)現(xiàn),應(yīng)該是sigmod函數(shù)。
這里要注意一個細(xì)節(jié):在RM訓(xùn)練完之后,會讓RM的輸出減去一個bias,使得reward score在人類寫的答案上(labeler demonstrations)的平均分為0。這里筆者沒找到具體在什么數(shù)據(jù)上統(tǒng)計(jì)的,猜測是在SFT數(shù)據(jù)上做的,如果有讀者知道是怎么做的歡迎指出。
Reinforcement Learning(RL)
直接看需要最大化的目標(biāo)函數(shù)
其中,πΦRL和πSFT分別是正在用RL訓(xùn)練的語言模型和SFT訓(xùn)練得到的模型。
上式中,
第一項(xiàng)期望式是在最大化reward的同時,最小化和SFT模型的per-token KL penalty,可以理解為是一種正則手段,兩者組合成關(guān)于prompt-Responce pair最終的Reward:R(x,y)=rθ(x,y)?βlog(πΦRL(y∣x)/πSFT(y∣x))。per-token KL penalty的好處如下:
1. 充當(dāng)熵紅利(Entropy bonus),鼓勵policy探索并阻止其坍塌為單一模式。
2. 確保策略模型產(chǎn)生的輸出 與 Reward Model在訓(xùn)練期間看到的輸出 不會相差太大,保證Reward的可靠性。僅含這一項(xiàng)就是單純使用了PPO。這里也可以看出來,Reward model的能力可能會成為RLHF的瓶頸。
第二項(xiàng)期望式是可選項(xiàng),注意到它其實(shí)是使用預(yù)訓(xùn)練的數(shù)據(jù)來做跟預(yù)訓(xùn)練同樣的任務(wù)(predict next word),因?yàn)檫@一項(xiàng)的數(shù)據(jù)不是模型生成的其實(shí)跟RL是并行的目標(biāo)。包含這一項(xiàng)的算法稱之為PPO-ptx。
PPO算法
本小節(jié)以最小知識補(bǔ)充為前提,快速介紹PPO,不用犯怵,很簡單而直觀。
通常來說,對于一個強(qiáng)化學(xué)習(xí)模型,會有一個做動作的策略網(wǎng)絡(luò)π,它根據(jù)自己觀測的狀態(tài)(si)做出動作(ai)跟環(huán)境交互,然后會拿到一個即刻的reward(ri), 同時進(jìn)入到下一個狀態(tài)(si+1);策略網(wǎng)絡(luò)再繼續(xù)觀測狀態(tài)si+1做下一個動作ai+1...直到達(dá)到最終狀態(tài)。這樣,策略網(wǎng)絡(luò)和環(huán)境的一系列互動后最終會得到一個軌跡(trajectory):τ=s1,a1,r1,s2,a2,r2,...,sT,aT,rT。
那么,在語言模型的場景下,策略網(wǎng)絡(luò)就是待微調(diào)的LLM,它所能做的動作就是預(yù)測下一個token,它觀測的轉(zhuǎn)狀態(tài)就是預(yù)測下一個token時所能觀測到的context(Prompt+這個token前所生成的所有tokens)。
reward除了最后一個rT等于上文提到的R(x,y)=rθ(x,y)?βlog(πRLΦ(yT∣x,y1,...,yT?1)/πSFT(yT∣x,y1,...,yT?1))
其他的ri=?βlog(πRLΦ(yi∣x,y1,...,yi?1)/πSFT(yi∣x,y1,...,yi?1))。
好,在LLM的場景中,現(xiàn)在可以統(tǒng)一一下符號:s1=x,ai=yi,si=cat(x,y1,y2,...,yi?1),其中x是prompt,yi是第i步蹦的token??吹竭@,了解PPO的同學(xué)基本上就清晰了RLHF具體是怎么做優(yōu)化的了,可以直接跳過下面的科普部分。
因?yàn)镻PO原文是基于Actor-Critic算法做的,Actor-Critic算法是進(jìn)階版的Policy Gradient算法。下面我們從policy gradient到Actor-Critic,再到PPO,幫助RL背景比較弱的讀者串一遍。
Policy Gradient(PG)算法
核心要義:用“Reward”作為權(quán)重,最大化策略網(wǎng)絡(luò)所做出的動作的概率。
偽代碼核心部分一句話的事:
用策略網(wǎng)絡(luò)πθ采樣出一個軌跡,然后根據(jù)即刻得到的rewardrt計(jì)算 discounted rewardRt=∑i=tTγi?tri;用Rt作為權(quán)重,最大化這個軌跡下所采取的動作的概率log(π(at∣st))?Rt,用梯度上升優(yōu)化之。
雖然在強(qiáng)化學(xué)習(xí)算法中對每一步都有一個即時的“reward”,但是每一步對后面的可能狀態(tài)都是有影響的。
即,后面的動作獲取的即時“reward”都能累計(jì)到前面的動作的貢獻(xiàn)。但是直接加上去可能不好,畢竟不是前面的動作直接獲取的reward,但是可以打個折扣再加上去,即乘個小于1的γ。
這里面讀者可能會有個問題:可是不好的動作也要最大化概率嗎?
這里有必要稍微展開一下:
1.Rt也可以是負(fù)的,對負(fù)的Rt那就是最小化動作at的概率,這也是為什么前面提到要對RM的輸出做歸一化的其中一個原因之一。
2.即便Rt都是正的,但只要充分采樣,同一個狀態(tài)下相對的Rt較小的動作也是會被抑制的,因?yàn)橥粋€狀態(tài)下的動作概率求和等于1,此消彼長,只有權(quán)重最大的動作才會得到獎勵。
可是,比如同一個狀態(tài)下,有兩個動作的Rt是正的,但是因?yàn)閯幼鞑蓸颖緛砭秃芟∈璧?,我們很可能不幸運(yùn)采樣到了相對較小的Rt對應(yīng)的動作,而沒有采樣到相對較大的。但因?yàn)樗钦?,這時候當(dāng)前的機(jī)制下,還是會鼓勵這個動作,這樣的話網(wǎng)絡(luò)很容易一直沿著不太好的策略去優(yōu)化。為了解決這個問題,我們引入Actor-Critic算法。
Actor-Critic (AC)算法
核心要義:再增加一個Critic網(wǎng)絡(luò)來構(gòu)造一個Reward baseline,只有獲得的reward比baseline要好才獎勵這個動作,否則抑制它。
Actor指的是策略網(wǎng)絡(luò)πθ;Criticb?目的就是給定一個策略網(wǎng)絡(luò),預(yù)估每個狀態(tài)st,策略網(wǎng)絡(luò)所能拿到期望rewardb?(st)是多少。什么是期望reward,無非就是在狀態(tài)st,對πθ采樣不同的動作at所能獲取的Rt的平均值嘛。我們要選擇的動作當(dāng)然是獲取的reward比平均reward要好的動作,不比baseline好的動作就得抑制它。
觀測上面算法2,其實(shí)對比PG算法就加了兩行:
1. 原來用Reward function來加權(quán),現(xiàn)在用Advantage function來加權(quán)。現(xiàn)在我們把b?(st)當(dāng)作一個baseline方法所能拿到的reward, 用采樣出來的at所拿到的rewardRt減去b?(st)作為最大化當(dāng)前動作概率的權(quán)重:At=Rt?b?(st)。其中 A_t 通常被稱作是Advantage function(或Advantage estimator),即優(yōu)勢函數(shù)。
2.拉近b?(st)和Rt的距離,初學(xué)者對這個可能會費(fèi)解。實(shí)則很好理解,記住b?在做什么,要預(yù)估當(dāng)前策略下Rt的期望,我只要不管三七二十一,每來一個動作的Rt都拉近一下距離,其實(shí)就是在預(yù)估平均值。更一般地:
其實(shí)上面用到的b?,它無非是換了皮的Vπθ?(st)(簡寫成V?(st)),即RL中的重要概念V function:給定策略πθ在st上的期望reward。那么最后一步 T 到達(dá)的state sT通常來講是沒有隨機(jī)性的(比如下棋,最后一個state決定贏輸就是固定的reward;LLM,最后一個token生成完,response確定了,reward也就確定了),因此rT應(yīng)該和V?(sT)相等。
所以我們可以重寫上面的優(yōu)勢函數(shù):
A^t=?V?(st)+rt+γrt+1+?+γT?tV?(sT)
寫成Generalized Advantage Estimation,當(dāng)λ=1 下式等于上式:
A^t=δt+(γλ)δt+1+?+(γλ)T?t+1δT?1
其中,δt=rt+γV?(st+1)?V?(st)是時序差分式(TD error)。
記住這個結(jié)論:這樣我們可以用A^t優(yōu)化πθ,現(xiàn)在我們可以用▽θlog(πθ(at∣st))?A^t來更新策略網(wǎng)絡(luò)了。
PPO Finetuning
上面提到的算法,有一個最嚴(yán)重的弊端是,一個軌跡只用一次就丟掉了??墒?,采樣軌跡通常是很耗時的,對應(yīng)到在LLM場景則需要做推理,眾所周知LLM的推理是比訓(xùn)練費(fèi)勁很多的,它需要一個一個地蹦詞??墒侵苯佑弥暗牟呗圆蓸映鰜淼臉颖緛韮?yōu)化現(xiàn)在的策略網(wǎng)絡(luò)肯定不行,如何合理復(fù)用樣本則是PPO要做的事情。
做法巨簡單,大致可以用這個思想來更新:
定義 動作概率比rt(θ)=πθold(at∣st)πθ(at∣st),用▽θrt(θ)?A^t去梯度上升更新策略網(wǎng)絡(luò),注意這里stat和A^t都是只之前的策略網(wǎng)絡(luò)πθold采樣得到的。這個公式,在筆者看來沒有直觀的解釋,需要一丟丟推導(dǎo),因?yàn)槭强破障蜻@里讀者先承認(rèn)就好了,后面筆者會單開一篇文章再重新梳理一遍。
本質(zhì)上是最大化這個目標(biāo)函數(shù):
但是如果πθ和πθold如果差別太大,就不能用這個式子優(yōu)化了,PPO給出的做法是給rt(θ)卡閾值,太大或太小就不用這一步的樣本更新了:
上面的目標(biāo)函數(shù)可以分類討論進(jìn)行分析,對優(yōu)勢函數(shù)A^t大于0和小于0兩種情況分析,這個目標(biāo)函數(shù)的圖像長這樣:
觀測圖像:
當(dāng)A^t大于0,要提高動作的概率,但是如果概率比之前大比較多了(πθ是πθold的1+?倍),就不提高了
當(dāng)A^t小于0,要減少動作的概率,但是如果概率比之前小比較多了(πθ是πθold的1??倍),就不減少了
偽代碼如下:
科普到此結(jié)束,看到這讀者就可以看懂RLHF的代碼。值得注意的是為了減少讀者負(fù)擔(dān)做了大量的敘述上的簡化,方法上是比較完備的,但是說法上不夠嚴(yán)謹(jǐn)。Again,更詳細(xì)的強(qiáng)化學(xué)習(xí)科普會單開一篇文章。
大語言模型的PPO
稍微整理一下,符號和上面的科普部分不一致,不過應(yīng)該不影響理解
1.現(xiàn)在我們的actor是SFT初始化的LLMπΦRL;
2.為了計(jì)算reward,我們需要兩個凍住參數(shù)網(wǎng)絡(luò),一個RM,一個是凍住的SFT模型πSFT用來計(jì)算KL散度,參考下面兩式子:rT=R(x,y)=rθ(x,y)?βlog(πRLΦ(yT∣x,y1,...,yT?1)/πSFT(yT∣x,y1,...,yT?1))其他步的ri=?βlog(πRLΦ(yi∣x,y1,...,yi?1)/πSFT(yi∣x,y1,...,yi?1));
3.為了執(zhí)行PPO算法,我們需要引入一個估計(jì)V值的網(wǎng)絡(luò)Vη,它初始化來自RM。所以統(tǒng)共,有4個網(wǎng)絡(luò),兩個訓(xùn)練的actorπΦRL和criticVη;兩個用來計(jì)算reward的SFT模型πSFT和RM模型。然后actor初始化來自SFT,critic初始化來自RM。
把這四個網(wǎng)絡(luò),結(jié)合reward的構(gòu)造,帶入到上面提到的PPO算法中,整個過程就比較清晰了。
盜一下DeepSpeed-Chat的圖,圖解如下:
看到這,相信讀者已經(jīng)可以輕易看懂的DeepSpeed-Chat代碼了。??
審核編輯:黃飛
-
ChatGPT
+關(guān)注
關(guān)注
29文章
1585瀏覽量
8710
原文標(biāo)題:PPO算法
文章出處:【微信號:zenRRan,微信公眾號:深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
拆解大語言模型RLHF中的PPO算法

一文解析BLDC電機(jī)控制算法
用PID算法調(diào)溫的經(jīng)驗(yàn)
C++的G代碼解析算法研究
一文解析機(jī)器學(xué)習(xí)常用35大算法

基于PPO強(qiáng)化學(xué)習(xí)算法的AI應(yīng)用案例
PID算法詳細(xì)解析——基于單片機(jī)

一文解析通信系統(tǒng)的高效正交變量優(yōu)化算法

評論