在 AI 領(lǐng)域,文本翻譯、語(yǔ)音識(shí)別、股價(jià)預(yù)測(cè)等場(chǎng)景都離不開(kāi)序列數(shù)據(jù)處理。循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)作為最早的序列建模工具,開(kāi)創(chuàng)了 “記憶歷史信息” 的先河;而長(zhǎng)短期記憶網(wǎng)絡(luò)(LSTM)則通過(guò)創(chuàng)新設(shè)計(jì),突破了 RNN 的核心局限。今天,我們從原理、梯度推導(dǎo)到實(shí)踐,全面解析這兩大經(jīng)典模型。
一、基礎(chǔ)鋪墊:RNN 的核心邏輯與痛點(diǎn)
RNN 的核心是讓模型 “記住過(guò)去”—— 通過(guò)隱藏層的循環(huán)連接,將前一時(shí)刻的信息傳遞到當(dāng)前時(shí)刻,從而捕捉序列的時(shí)序關(guān)聯(lián)。但這種 “全記憶” 設(shè)計(jì),也埋下了梯度消失的隱患。
1.1 核心結(jié)構(gòu)與參數(shù)

RNN 結(jié)構(gòu)簡(jiǎn)化為 “輸入層 - 隱藏層 - 輸出層”,關(guān)鍵組件如下:
- 輸入:Xt(第 t 時(shí)刻輸入,如文本中的詞向量)
- 隱藏狀態(tài):St(存儲(chǔ)截至 t 時(shí)刻的歷史信息,核心記憶載體)
- 輸出:Ot(第 t 時(shí)刻預(yù)測(cè)結(jié)果,如分類(lèi)標(biāo)簽)
- 共享參數(shù)(所有時(shí)間步復(fù)用):Wx:輸入→隱藏層權(quán)重矩陣(維度:隱藏層維度 × 輸入維度)Ws:隱藏層→自身的循環(huán)權(quán)重矩陣(維度:隱藏層維度 × 隱藏層維度,關(guān)鍵)Wo:隱藏層→輸出層權(quán)重矩陣(維度:輸出維度 × 隱藏層維度)偏置:b?(隱藏層偏置,維度:隱藏層維度 ×1)、b?(輸出層偏置,維度:輸出維度 ×1)
- 激活函數(shù):隱藏層用 tanh(值縮至 [-1,1]),輸出層用 Softmax(分類(lèi))或線性激活(回歸)
1.2 前向傳播:信息如何流動(dòng)?
前向傳播是 “輸入→輸出” 的計(jì)算過(guò)程,每個(gè)時(shí)間步的結(jié)果依賴(lài)前一時(shí)刻的隱藏狀態(tài)(以下基于標(biāo)量簡(jiǎn)化,向量場(chǎng)景邏輯一致):更新隱藏狀態(tài)

當(dāng)前記憶 St由 “當(dāng)前輸入 Xt” 和 “歷史記憶 St-1” 共同決定,tanh 確保狀態(tài)值在合理范圍。計(jì)算輸出

輸出僅依賴(lài)當(dāng)前記憶St,體現(xiàn) “歷史信息已壓縮到St中”。示例:若序列長(zhǎng)度為 3(t=1,2,3),初始狀態(tài) S?=0(無(wú)歷史信息):

1.3 反向傳播(BPTT)與梯度推導(dǎo)
模型訓(xùn)練依賴(lài)時(shí)間反向傳播(BPTT):通過(guò)鏈?zhǔn)椒▌t回溯所有時(shí)間步,計(jì)算損失對(duì)參數(shù)的梯度,再用梯度下降更新參數(shù)。假設(shè)損失函數(shù)為交叉熵?fù)p失 Loss = L (Ot, yt)(yt為 T 時(shí)刻真實(shí)標(biāo)簽),核心是推導(dǎo) Loss 對(duì) Wx、Ws、Wo的梯度。1.3.1 核心梯度推導(dǎo)步驟步驟 1:計(jì)算 Loss 對(duì)輸出Ot的梯度若輸出層用 Softmax 激活 + 交叉熵?fù)p失,對(duì)單個(gè)樣本有:

當(dāng)i=j時(shí),等于:

當(dāng)i≠j時(shí),等于:

所以,softmax函數(shù)的導(dǎo)數(shù)可以表示為:

我們只需要將softmax層的輸出pi,pj代入上面的公式就可以做求導(dǎo)計(jì)算了。在多分類(lèi)任務(wù)中,我們通常使用交叉熵?fù)p失函數(shù)(cross-entropy loss function)來(lái)評(píng)估模型的性能。交叉熵?fù)p失函數(shù)的定義如下:

其中yj是真實(shí)標(biāo)簽的one??ot向量,pj是softmax函數(shù)的輸出。交叉熵?fù)p失函數(shù)的作用是衡量模型的預(yù)測(cè)概率p和真實(shí)標(biāo)簽y之間的差異。交叉熵?fù)p失越小,表示模型的預(yù)測(cè)值越接近真實(shí)的標(biāo)簽。經(jīng)驗(yàn)告訴我們,當(dāng)使用softmax函數(shù)作為輸出層激活函數(shù)時(shí),最好使用交叉熵作為其損失函數(shù),這是因?yàn)榻徊骒睾蛃oftmax函數(shù)的結(jié)合可以簡(jiǎn)化反向傳播的計(jì)算。為了證明這一點(diǎn),我們對(duì)交叉熵函數(shù)求導(dǎo):

其中?pj/?zi就是上文推導(dǎo)的softmax的導(dǎo)數(shù),將其代入式中可得:

所以y是one-hot向量,所以:

最后,化簡(jiǎn)得到的交叉熵函數(shù)的求導(dǎo)公式:

步驟 2:計(jì)算 Loss 對(duì)隱藏狀態(tài) S?的梯度隱藏狀態(tài)St同時(shí)影響當(dāng)前輸出Ot和下一時(shí)刻隱藏狀態(tài) St+1,因此梯度需分兩部分:

拆解導(dǎo)數(shù)項(xiàng):

由St+1=tanh(WxXt+1+WsSt+b1)求導(dǎo):tanh'(x)=1?tanh2(x)因此遞推公式為:

(向量場(chǎng)景需轉(zhuǎn)置)步驟 3:計(jì)算 Loss 對(duì)參數(shù)的梯度對(duì) Wo的梯度:

(向量場(chǎng)景下為外積)對(duì) Wx的梯度:W?在所有時(shí)間步共享,需累加各時(shí)間步貢獻(xiàn):

對(duì)Ws的梯度:同理,Ws的梯度為各時(shí)間步貢獻(xiàn)的累加:

1.3.2 梯度消失的核心原因:累乘衰減
從Ws的梯度公式可見(jiàn),遠(yuǎn)時(shí)刻(如 t=1)對(duì)梯度的貢獻(xiàn)需經(jīng)過(guò)多次 tanh'(Sk)?Ws的累乘(k 從 2 到 T):tanh'(Sk) ∈ [0,1](tanh 導(dǎo)數(shù)特性,最大值為 1,多數(shù)時(shí)刻小于 0.5)|Ws| < 1(為避免數(shù)值爆炸,初始化時(shí)會(huì)限制權(quán)重范圍)導(dǎo)致累乘項(xiàng)隨時(shí)間步指數(shù)級(jí)衰減,例如:若tanh'(Sk)=0.5,|Ws|=0.8,序列長(zhǎng)度T=10,則累乘項(xiàng) =(0.5×0.8)^9≈0.00026,遠(yuǎn)時(shí)刻梯度趨近于 0,模型無(wú)法捕捉長(zhǎng)期依賴(lài)。
突破局限:LSTM 的創(chuàng)新設(shè)計(jì)與梯度推導(dǎo)
1997 年提出的 LSTM,通過(guò)“記憶細(xì)胞 + 門(mén)控機(jī)制”實(shí)現(xiàn) “選擇性記憶”—— 保留重要信息、過(guò)濾噪聲,從根本上緩解梯度消失。
2.1 核心結(jié)構(gòu):三門(mén) + 記憶細(xì)胞

LSTM 的核心是 “記憶細(xì)胞(C?)” 和三個(gè)門(mén)控,分工明確(以下基于標(biāo)量簡(jiǎn)化):
組件 | 功能 | 激活函數(shù) | 參數(shù)(權(quán)重 + 偏置) |
記憶細(xì)胞 Ct | 長(zhǎng)期記憶載體,狀態(tài)平緩更新 | 無(wú) | 依賴(lài)門(mén)控參數(shù) |
遺忘門(mén) ft | 控制保留多少歷史細(xì)胞狀態(tài) Ct-1 | σ(Sigmoid,輸出[0,1]) | Wxf、W?f、bf |
更新門(mén) it | 控制加入多少新信息到 Ct | σ(輸出 [0,1]) | Wxi、W?i、bi |
候選記憶 gt | 生成當(dāng)前時(shí)刻的新候選信息 | tanh(輸出[-1,1]) | Wxg、W?g、bg |
輸出門(mén) ot | 控制 Ct輸出到隱藏狀態(tài) ht的比例 | σ(輸出 [0,1]) | Wxo、W?o、bo |
隱藏狀態(tài) ht | 短期記憶,用于當(dāng)前輸出 | tanh(輸出[-1,1]) | Wyo、bo |
? | 元素相乘 | 無(wú) | 無(wú) |
⊕ | 元素相加 | 無(wú) | 無(wú) |
σ 函數(shù)輸出 [0,1],完美適配 “門(mén)控控制”(1 = 完全保留,0 = 完全過(guò)濾);tanh 確保信息值在合理范圍
2.2 前向傳播:5 步完成記憶更新
LSTM 的前向傳播圍繞 “記憶細(xì)胞更新” 展開(kāi),步驟清晰:遺忘門(mén):決定 “丟什么”ft=σ(Wxf?Xt+W?f??t?1+bf)例:ft=0.9→保留 90% 歷史記憶Ct?1;ft=0.1→過(guò)濾 90% 舊信息。更新門(mén) + 候選記憶:決定 “加什么”更新門(mén):it=σ(Wxi?Xt+W?i?ht?1+bi)(控制新信息的權(quán)重)候選記憶:gt=tanh(Wxg?Xt+W?g?ht?1+bg)(當(dāng)前時(shí)刻的新信息)更新記憶細(xì)胞:“丟舊 + 加新”Ct=Ct?1?ft+gt?it?為對(duì)應(yīng)元素相乘,Ct同時(shí)承載 “長(zhǎng)期歷史Ct?1?ft” 和 “當(dāng)前新信息gt?it”。輸出門(mén):決定 “輸出什么”ot=σ(Wxo?Xt+W?o??t?1+bo)生成隱藏狀態(tài)與輸出ht=ot?tanh (Ct)(tanh 將Ct縮至 [-1,1],再通過(guò)ot過(guò)濾)yt=Wy???t+by(最終預(yù)測(cè)結(jié)果,分類(lèi)任務(wù)需加 Softmax)
2.3 反向傳播與梯度推導(dǎo)
LSTM 的反向傳播仍基于 BPTT,但需同時(shí)更新三門(mén)參數(shù)和記憶細(xì)胞相關(guān)梯度,核心是確保記憶細(xì)胞 C?的梯度穩(wěn)定傳遞。假設(shè)損失 Loss = L (yt,y't)(y't為真實(shí)標(biāo)簽),以下為關(guān)鍵梯度推導(dǎo)。
2.3.1 核心梯度 1:Loss 對(duì)記憶細(xì)胞 C?的導(dǎo)數(shù)
記憶細(xì)胞Ct同時(shí)影響當(dāng)前隱藏狀態(tài)?t和下一時(shí)刻記憶細(xì)胞Ct+1,梯度公式為:

拆解導(dǎo)數(shù)項(xiàng):?Loss/??t:損失對(duì)隱藏狀態(tài)的梯度,由輸出層反向推導(dǎo):

(包含當(dāng)前輸出和下一時(shí)刻四門(mén)的貢獻(xiàn))??t/?Ct=ot?tanh'(Ct)(由?t=ot?tanh (Ct)求導(dǎo))?Ct+1/?Ct=ft+1(由Ct+1=Ct?ft+1+gt+1?it+1求導(dǎo))最終遞推公式:

2.3.2 核心梯度 2:Loss 對(duì)門(mén)控參數(shù)的導(dǎo)數(shù)(以遺忘門(mén)為例)遺忘門(mén)參數(shù)(Wxf、Whf、bf)的梯度需通過(guò)鏈?zhǔn)椒▌t推導(dǎo):先求 Loss 對(duì)遺忘門(mén)輸出ft的梯度:

再求 Loss 對(duì)遺忘門(mén)權(quán)重 Wxf的梯度:

(σ函數(shù)導(dǎo)數(shù)為σ(x)?(1-σ(x)),此處 ft=σ(...),故?ft/?Wxf?ft?(1?ft)?Xt)同理,Loss對(duì)Whf的梯度:

更新門(mén)、輸出門(mén)、候選記憶的參數(shù)梯度推導(dǎo)邏輯一致,最終所有參數(shù)通過(guò)梯度下降(如 Adam 優(yōu)化器)更新。2.3.3 LSTM 如何緩解梯度消失?對(duì)比 RNN 的梯度路徑,LSTM 的記憶細(xì)胞梯度傳遞具有決定性?xún)?yōu)勢(shì):從?Loss?Ct的遞推公式可見(jiàn),當(dāng)模型需要保留長(zhǎng)期信息時(shí),會(huì)通過(guò)參數(shù)學(xué)習(xí)使遺忘門(mén)ft+1≈1,此時(shí):

由于 tanh'(Ct)∈[0,1],ot∈[0,1],但核心是?Loss/?Ct+1直接傳遞到?Loss/?Ct,無(wú)指數(shù)級(jí)衰減。即使序列長(zhǎng)度達(dá)到 100+,遠(yuǎn)時(shí)刻(如 t=1)的梯度仍能穩(wěn)定傳遞到當(dāng)前時(shí)刻(如 t=100),從而有效捕捉長(zhǎng)期依賴(lài)。
關(guān)鍵補(bǔ)充:模型如何 “學(xué)習(xí)” 讓ft+1≈1?
遺忘門(mén)ft+1的輸出由以下公式?jīng)Q定:ft+1=σ(Wxf?Xt+1+W?f??t+bf)
其中σ是 Sigmoid 函數(shù),當(dāng)輸入值>2 時(shí),σ(x)≈0.95(接近1)。模型通過(guò)以下兩種方式學(xué)習(xí)讓ft+1≈1:
初始化階段:設(shè)置遺忘門(mén)偏置 bf>0
工程實(shí)踐中,會(huì)將遺忘門(mén)的偏置bf初始化為1~2(而非默認(rèn)0),此時(shí)即使Wxf?Xt+1+W?f??t=0,ft+1=σ(bf)≈0.73(已較高),為后續(xù)學(xué)習(xí) “保留長(zhǎng)期信息” 奠定基礎(chǔ)。訓(xùn)練階段:通過(guò)損失反向調(diào)整參數(shù)當(dāng)模型因 “未保留遠(yuǎn)時(shí)刻信息” 導(dǎo)致 Loss 升高時(shí),反向傳播會(huì)調(diào)整Wxf、W?f、bf的取值:若 t=1 的信息對(duì) t=100 的預(yù)測(cè)很重要,但當(dāng)前f2=0.1(過(guò)濾了 t=1 的信息),則 Loss 會(huì)增大;反向傳播時(shí),?Loss/?f2為正值(增加f2可降低 Loss),進(jìn)而通過(guò)?Loss/?Wf調(diào)整權(quán)重,使f2增大;反復(fù)迭代后,模型會(huì)學(xué)習(xí)到 “對(duì)重要的長(zhǎng)期信息,讓ft+1≈1。
三、RNN vs LSTM:怎么選?
兩大模型各有優(yōu)劣,需結(jié)合場(chǎng)景匹配:
維度 | 循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN) | 長(zhǎng)短期記憶網(wǎng)絡(luò)(LSTM) |
記憶能力 | 僅短期依賴(lài) | 長(zhǎng)期依賴(lài)(序列長(zhǎng)度 100+) |
梯度問(wèn)題 | 易出現(xiàn)梯度消失,遠(yuǎn)時(shí)刻信息丟失 | 記憶細(xì)胞梯度穩(wěn)定,緩解梯度消失 |
模型復(fù)雜度 | 低(僅 3 組核心參數(shù):W?、W?、W?) | 高(9 組核心參數(shù):3 門(mén) ×3 組權(quán)重 + 輸出層權(quán)重) |
參數(shù)數(shù)量 | 少(如隱藏層維度 H=128,輸入維度 D=64,參數(shù)量≈1282+128×64=24576) | 多(同上述維度,參數(shù)量≈4×(1282+128×64)=98304,約為 RNN 的 4 倍) |
計(jì)算效率 | 快(前向 / 反向傳播步驟少) | 慢(門(mén)控計(jì)算多) |
訓(xùn)練難度 | 低(參數(shù)少,收斂快,易實(shí)現(xiàn)) | 高(參數(shù)多,易過(guò)擬合,需更多數(shù)據(jù)和正則化) |
核心優(yōu)勢(shì) | 結(jié)構(gòu)簡(jiǎn)單、訓(xùn)練速度快、資源占用低 | 魯棒性強(qiáng)、長(zhǎng)期依賴(lài)捕捉能力突出、任務(wù)精度高 |
四、工程實(shí)踐小貼士4.1 模型選擇策略先簡(jiǎn)后繁:先用 RNN 驗(yàn)證短序列任務(wù)可行性,若精度不達(dá)標(biāo)(如測(cè)試集準(zhǔn)確率 < 85%),再替換為 LSTM;折中方案:若 LSTM 計(jì)算壓力大,可選用 GRU(門(mén)控循環(huán)單元)—— 簡(jiǎn)化為重置門(mén)和更新門(mén) 2 個(gè)門(mén),參數(shù)量比 LSTM 少 25%,性能接近 LSTM;
數(shù)據(jù)適配:若序列長(zhǎng)度差異大(如文本長(zhǎng)度 5-200 詞),可采用 “截?cái)?+ 填充”(固定序列長(zhǎng)度)或 “動(dòng)態(tài)批處理”(同批次序列長(zhǎng)度一致)。4.2 LSTM 性能優(yōu)化技巧參數(shù)裁剪:隱藏層維度從 256 降至 128,參數(shù)量減少 75%,訓(xùn)練速度提升 2-3 倍;序列分段:將長(zhǎng)序列(如 1000 幀音頻)拆分為 10 個(gè) 100 幀子序列,采用 “滾動(dòng)預(yù)測(cè)” 拼接結(jié)果;量化訓(xùn)練:將 32 位浮點(diǎn)數(shù)參數(shù)轉(zhuǎn)為 16 位半精度,顯存占用減少 50%,推理速度提升 1.5 倍;正則化:添加 Dropout(隱藏層 dropout 率 0.2-0.5)、L2 正則化(權(quán)重衰減系數(shù) 1e-4),緩解過(guò)擬合。
4.3 常見(jiàn)問(wèn)題排查
問(wèn)題現(xiàn)象 | 可能原因 | 解決方案 |
訓(xùn)練 loss 不下降 | 1. 學(xué)習(xí)率過(guò)高 / 過(guò)低2. 梯度消失(LSTM 遺忘門(mén)ft過(guò)?。?/p> | 1. 調(diào)整學(xué)習(xí)率(如 Adam優(yōu)化器默認(rèn) 0.001,可嘗試0.0001-0.01)2. 初始化遺忘門(mén)偏置bf為1-2(使ft初始值接近 1) |
測(cè)試集 loss 波動(dòng)大 | 1. 數(shù)據(jù)量不足2. 序列長(zhǎng)度分布不均 | 1. 數(shù)據(jù)增強(qiáng)(如文本同義詞替換、時(shí)序數(shù)據(jù)加噪)2. 按序列長(zhǎng)度分組訓(xùn)練,平衡各長(zhǎng)度樣本占比 |
總結(jié)RNN 作為序列建模的 “基石”,以簡(jiǎn)單的循環(huán)結(jié)構(gòu)開(kāi)創(chuàng)了歷史信息復(fù)用的思路,但受限于梯度消失無(wú)法處理長(zhǎng)序列;LSTM 則通過(guò)記憶細(xì)胞和門(mén)控機(jī)制的創(chuàng)新,從梯度傳遞路徑上解決了長(zhǎng)期依賴(lài)問(wèn)題,成為長(zhǎng)序列任務(wù)的經(jīng)典方案。盡管當(dāng)前 Transformer(如 BERT、GPT)在多數(shù)序列任務(wù)中表現(xiàn)更優(yōu),但 RNN 和 LSTM 的核心思想(時(shí)序關(guān)聯(lián)捕捉、選擇性記憶)仍是理解復(fù)雜序列模型的基礎(chǔ),也是 AI 工程師在資源受限場(chǎng)景下的重要選擇。你在項(xiàng)目中用過(guò) RNN 或 LSTM 嗎?遇到過(guò)哪些訓(xùn)練難題?歡迎在評(píng)論區(qū)分享你的實(shí)踐經(jīng)驗(yàn)!
本文轉(zhuǎn)自:秦芯智算
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4829瀏覽量
106806 -
rnn
+關(guān)注
關(guān)注
0文章
92瀏覽量
7301 -
LSTM
+關(guān)注
關(guān)注
0文章
63瀏覽量
4296
發(fā)布評(píng)論請(qǐng)先 登錄

一文讀懂LSTM與RNN:從原理到實(shí)戰(zhàn),掌握序列建模核心技術(shù)
評(píng)論