前言
這篇文章的主要內(nèi)容是,解讀 AlphaTensor 這篇論文的主要思想,如何通過(guò)強(qiáng)化學(xué)習(xí)來(lái)探索發(fā)現(xiàn)更高效的矩陣乘算法。
1、二進(jìn)制加法和乘法
這一節(jié)簡(jiǎn)單介紹一下計(jì)算機(jī)是怎么實(shí)現(xiàn)加法和乘法的。
以 2 + 5 和 2 * 5 為例。
我們知道數(shù)字在計(jì)算機(jī)中是以二進(jìn)制形式表示的。
整數(shù)2的二進(jìn)制表示為:0010
整數(shù)5的二進(jìn)制表示為:0101
1.1、二進(jìn)制加法
二進(jìn)制加法很簡(jiǎn)單,也就是兩個(gè)二進(jìn)制數(shù)按位相加,如下圖所示:
當(dāng)然具體到硬件實(shí)現(xiàn)其實(shí)是包含了異或運(yùn)算和與運(yùn)算,具體細(xì)節(jié)可以閱讀文末參考的資料。
1.2、二進(jìn)制乘法
二進(jìn)制乘法其實(shí)也是通過(guò)二進(jìn)制加法來(lái)實(shí)現(xiàn)的,如下圖所示:
乘法在硬件上的實(shí)現(xiàn)本質(zhì)是移位相加。
對(duì)于二進(jìn)制數(shù)來(lái)說(shuō)乘數(shù)和被乘數(shù)的每一位非0即1。
所以相當(dāng)于乘數(shù)中的每一位從低位到高位,分別和被乘數(shù)的每一位進(jìn)行與運(yùn)算并產(chǎn)生其相應(yīng)的局部乘積,再將這些局部乘積左移一位與上次的和相加。
從乘數(shù)的最低位開(kāi)始:
若為1,則復(fù)制被乘數(shù),并左移一位與上一次的和相加;
若為0,則直接將0左移一位與上一次的和相加;
如此循環(huán)至乘數(shù)的最高位。
從二進(jìn)制乘法的實(shí)現(xiàn)也可以看出來(lái),加法比乘法操作要快。
1.3、用加法替換乘法的簡(jiǎn)單例子
上面這個(gè)公式相信大家都很熟悉了,式子兩邊是等價(jià)的
左邊包含了2次乘法和1次加法(減法也可以看成加法)
右邊則包含了1次乘法和2次加法
可以看到通過(guò)數(shù)學(xué)上的等價(jià)變換,增加了加法的次數(shù)同時(shí)減少了乘法的次數(shù)。
2、矩陣乘算法
對(duì)于兩個(gè)大小分別為 Q x R 和 R x P 的矩陣相乘,通用的實(shí)現(xiàn)就需要 Q * P * R 次乘法操作(輸出矩陣大小 Q x P,總共 Q * P 個(gè)元素,每個(gè)元素計(jì)算需要 R 次乘法操作)。
根據(jù)前面 1.2內(nèi)容可知,乘法比加法慢,所以如果能減少的乘法次數(shù)就能有效加速矩陣乘的運(yùn)算。
2.1、通用矩陣乘算法
首先來(lái)看一下通用的矩陣乘算法:
如上圖所示,兩個(gè)大小為2x2矩陣做乘法,總共需要8次乘法和4次加法。
2.2、Strassen 矩陣乘算法
上圖所示即為 Strassen 矩陣乘算法,和通用矩陣乘算法不一樣的地方是,引入了7個(gè)中間變量 m,只有在計(jì)算這7個(gè)中間變量才會(huì)用到乘法。
簡(jiǎn)單用 c1 驗(yàn)證一下:
可以看到 Strassen 算法總共包含7次乘法和18次加法,通過(guò)數(shù)學(xué)上的等價(jià)變換減少了1次乘法同時(shí)增加了14次加法。
3、AlphaTensor 核心思想解讀
3.1、將矩陣乘表示為3維張量
首先來(lái)看下論文中的一張圖
圖中下方是3維張量,每個(gè)立方體表示3維張量一個(gè)坐標(biāo)點(diǎn)。
其中張量每個(gè)位置的值只能是 0 或者 1,透明的立方體表示 0,紫色的立方體表示 1。
現(xiàn)在將圖簡(jiǎn)化一下,以[a,b,c]這樣的維度順序,將張量以維度a平攤開(kāi),這樣更容易理解:
這個(gè)3維張量怎么理解呢?
比如對(duì)于 c1,我們知道 c1 的計(jì)算需要用到 a1,a2,b1,b3,對(duì)應(yīng)到3維張量就是:
而從上圖可知,對(duì)于兩個(gè) 2 x 2 的矩陣相乘,3維張量大小為 4 x 4 x 4。
一般的,對(duì)于兩個(gè) n x n 的矩陣相乘,3維張量大小為 n^2 x n^2 x n^2。
更一般的,對(duì)于兩個(gè) n x m 和 m x p 的矩陣相乘,3維張量大小為 n*m x m*p x n*p。
然后論文中為了簡(jiǎn)化理解,都是以 n x n 矩陣乘來(lái)講解的,論文中以
表示 n x n 矩陣乘的3維張量,下文中為了方便寫(xiě)作以 Tn 來(lái)表示。
3.2、3維張量分解
然后論文中提出了一個(gè)假設(shè):
如果能將3維張量 Tn 分解為 R 個(gè)秩1的3維張量(R rank-one terms)的和的話,那么對(duì)于任意的 n x n 矩陣乘計(jì)算就只需要 R 次乘法。
如上圖公式所示,就是表示的這個(gè)分解,其中的
就表示的一個(gè)秩1的3維張量,是由 u^(r) 、 v^(r) 和 ?w^(r) 這3個(gè)一維向量做外積得到的。
這具體怎么什么理解呢?我們回去看一下 Strassen 矩陣乘算法:
上圖左邊就是 Strassen 矩陣乘算法的計(jì)算過(guò)程,右邊的 U,V 和 W 3個(gè)矩陣,各自分別對(duì)應(yīng)左邊 U -> a, V -> b 和 W -> m。
具體又怎么理解這三個(gè)矩陣呢?
我們?cè)趫D上加一些標(biāo)注來(lái)解釋,其中 U , V 和 W 矩陣每一列從左到右按順序,就對(duì)應(yīng)上文提到的,u^(r) 、 v^(r) 和 ?w^(r) 這3個(gè)一維向量。
然后矩陣 U 每一列和 [a1,a2,a3,a4] 做內(nèi)積,矩陣 V 每一列和 [b1,b2,b3,b4] 做內(nèi)積,然后內(nèi)積結(jié)果相乘就得到 [m1,m2,m3,m4,m5,m6,m7]了。
最后矩陣 W 每一行和 [m1,m2,m3,m4,m5,m6,m7] 做內(nèi)積就得到 [c1,c2,c3,c4]。
接著再看一下的 U,V 和 W 這三個(gè)矩陣第一列的外積結(jié)果
如下圖所示:
可以看到 U,V 和 W 三個(gè)矩陣每一列對(duì)應(yīng)的外積的結(jié)果就是一個(gè)3維張量,那么這些3維張量全部加起來(lái)就會(huì)得到 Tn 么?下面我們來(lái)驗(yàn)證一下:
?
可以看到這些外積的結(jié)果全部加起來(lái)就恰好等于 Tn:
?
所以也就證實(shí)了開(kāi)頭的假設(shè):
如果能將表示矩陣乘的3維張量 Tn 分解為 R 個(gè)秩1的3維張量(R rank-one terms)的和,那么對(duì)于任意的 n x n 矩陣乘計(jì)算就只需要 R 次乘法。
因此也就很自然的可以想到,如果能找到更優(yōu)的張量分解,也就是讓 R 更小的話,那么就相當(dāng)于找到乘法次數(shù)更小的矩陣乘算法了。
通過(guò)強(qiáng)化學(xué)習(xí)探索更優(yōu)的3維張量分解
將探索3維張量分解過(guò)程變成游戲
論文中是采用了強(qiáng)化學(xué)習(xí)這個(gè)框架,來(lái)探索對(duì)3維張量Tn的更優(yōu)的分解。強(qiáng)化學(xué)習(xí)的環(huán)境是一個(gè)單玩家的游戲(a single-player game, TensorGame)。
首先定義這個(gè)游戲進(jìn)行 t 步之后的狀態(tài)為 St:
然后初始狀態(tài) S0 就設(shè)置為要分解的3維張量 Tn:
?
對(duì)于游戲中的每一步t,玩家(就是本論文提出的 AlphaTensor)會(huì)根據(jù)當(dāng)前的狀態(tài)選擇下一步的行動(dòng),也就是通過(guò)生成新的三個(gè)一維向量從而得到新的秩1張量:
?
接著更新?tīng)顟B(tài) St減去這個(gè)秩1張量:
?
玩家的目標(biāo)就是,讓最終狀態(tài) St=0同時(shí)盡量的減少游戲的步數(shù)。
當(dāng)?shù)竭_(dá)最終狀態(tài) St=0 之后,也就找到了3維張量Tn的一個(gè)分解了:
?
還有些細(xì)節(jié)是,對(duì)于玩家每一步的選擇都是給一個(gè) -1 的分?jǐn)?shù)獎(jiǎng)勵(lì),其實(shí)也很容易理解,也就是玩的步數(shù)越多,獎(jiǎng)勵(lì)越低,從而鼓勵(lì)玩家用更少的步數(shù)完成游戲。
而且對(duì)于一維向量的生成,也做了限制
?
就是生成這些一維向量的值,只限定在比如 [?2,??1,?0,?1,?2] 這5個(gè)離散值之內(nèi)。
AlphaTensor 簡(jiǎn)要解讀
論文中是怎么說(shuō)的,在游戲過(guò)程中玩家 AlphaTensor 是通過(guò)一個(gè)深度神經(jīng)網(wǎng)絡(luò)來(lái)指導(dǎo)蒙特卡洛樹(shù)搜索(MonteCarlo tree search)。關(guān)于這個(gè)蒙特卡洛樹(shù)搜索,我不是很了解這里就不做解讀了,有興趣的讀者可以閱讀文末參考資料。
首先看下深渡神經(jīng)網(wǎng)絡(luò)部分:
?
深度神經(jīng)網(wǎng)絡(luò)的輸入是當(dāng)前的狀態(tài) St也就是需要分解的張量(上圖中的最右邊的粉紅色立方體)。輸出包含兩個(gè)部分,分別是 Policy head 和 Value head。
其中 Policy head 的輸出是對(duì)于當(dāng)前狀態(tài)可以采取的潛在下一步行動(dòng),也就是一維向量(u(t),?v(t),?w(t)) 的候選分布,然后通過(guò)采樣得到下一步的行動(dòng)。
然后 Value head 應(yīng)該是對(duì)于給定的當(dāng)前的狀態(tài) St ,估計(jì)游戲完成之后的最終獎(jiǎng)勵(lì)分?jǐn)?shù)的分布。
接下來(lái)簡(jiǎn)要解讀一下整個(gè)游戲的流程,還有深度神經(jīng)網(wǎng)絡(luò)是如何訓(xùn)練的:
先看流程圖的上方 Acting 那個(gè)方框內(nèi),表示的是用訓(xùn)練好的網(wǎng)絡(luò)做推理玩游戲的過(guò)程。
可以看到最左邊綠色的立方體,也就是待分解的3維張量 Tn變換到粉紅色立方體,論文中提到是作了基的變換,但是這塊感覺(jué)如果不是去復(fù)現(xiàn)就不用了解的那么深入,而且我也沒(méi)去細(xì)看這塊就跳過(guò)吧。
然后從最初待分解的 Tn 開(kāi)始,輸入到神經(jīng)網(wǎng)絡(luò),通過(guò)蒙特卡洛樹(shù)搜索得到秩1張量,然后減去該張量之后,繼續(xù)將相減的結(jié)果輸入到網(wǎng)路中,繼續(xù)這個(gè)過(guò)程直到張量相減的結(jié)果為0。
將游戲過(guò)程記錄下來(lái),就是流程圖最右邊的 Played game。
然后流程圖下方的 Learning 方框表示的就是訓(xùn)練過(guò)程,訓(xùn)練數(shù)據(jù)有兩個(gè)部分,一個(gè)是已經(jīng)玩過(guò)的游戲記錄 Played games buffer 還有就是通過(guò)人工生成的數(shù)據(jù)。
人工怎么生成訓(xùn)練數(shù)據(jù)呢?
論文中提到,盡管張量分解是個(gè) NP-hard 的問(wèn)題,給定一個(gè) Tn 要找其分解很難。但是我們可以反過(guò)來(lái)用秩1張量來(lái)構(gòu)造出一個(gè)待分解的張量嘛!簡(jiǎn)單來(lái)說(shuō)就是采樣R個(gè)秩1張量,然后加起來(lái)就能的到分解的張量了。
因?yàn)閷?duì)于強(qiáng)化學(xué)習(xí)這塊我不是了解的并不深入,所以也就只能作粗淺的解讀。
實(shí)驗(yàn)結(jié)果
最后看一下實(shí)驗(yàn)結(jié)果
表格最左邊一列表示矩陣乘的規(guī)模,最右邊三列表示矩陣乘算法乘法次數(shù)。
第一列表示目前為止,數(shù)學(xué)家找到的最優(yōu)乘法次數(shù)。
第2和3列就是 AlphaTensor 找到的最優(yōu)乘法次數(shù)。
可以看到其中有5個(gè)規(guī)模,AlphaTensor 能找到更優(yōu)的乘法次數(shù)(標(biāo)紅的部分):
兩個(gè) 4 x 4 和 4 x 4 的矩陣乘,AlphaTensor 搜索出47次乘法;
兩個(gè) 5 x 5 和 5 x 5 的矩陣乘,AlphaTensor 搜索出96次乘法;
兩個(gè) 3 x 4 和 4 x 5 的矩陣乘,AlphaTensor 搜索出47次乘法;
兩個(gè) 4 x 4 和 4 x 5 的矩陣乘,AlphaTensor 搜索出63次乘法;
兩個(gè) 4 x 5 和 5 x 5 的矩陣乘,AlphaTensor 搜索出76次乘法;
審核編輯:劉清
評(píng)論