NAS最近也很火,正好看到了這篇論文,解讀一下,這篇論文是基于DAG(directed acyclic graph)的,DAG包含了上億的 sub-graphs, 為了防止全部遍歷這些模型,這篇論文設(shè)計了一種全新的采樣器,這種采樣器叫做Gradient-based search suing differential Architecture Sampler(GDAS),該采樣器可以自行學(xué)習(xí)和優(yōu)化,在這個的基礎(chǔ)上,在CIFAR-10上通過4 GPU hours就能找到一個最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)。
目前主流的NAS一般是基于進(jìn)化算法(EA)和強(qiáng)化學(xué)習(xí)(RL)來做的。EA通過權(quán)衡validation accuracy來決定是否需要移除一個模型,RL則是validation accuracy作為獎勵來優(yōu)化模型生成。作者認(rèn)為這兩種方法都很消耗計算資源。作者這篇論文中設(shè)計的GDAS方法可以在一個單v100 GPU上,用四小時搜索到一個優(yōu)秀模型。
GDAS
這個采用了搜索robust neural cell來替代搜索整個網(wǎng)絡(luò)。如下圖,不同的操作(操作用箭頭表示)會計算出不同的中間結(jié)果(中間結(jié)果用cycle表示),前面的中間結(jié)果會加起來闖到后面。
在優(yōu)化速度上,傳統(tǒng)的DAG存在一些問題:基于RL和EA的方法,需要獲得反饋都需要很長一段時間。而這篇論文提出的GDAS方法能夠利用梯度下降去做優(yōu)化,具體怎么梯的下面會說到。此外,使用GDAS的方法可以sample出sub-graph,這意味著計算量要比DAG的方法小很多。
絕大多數(shù)的NAS方法可以歸為兩類:Macro search和micro search
Macro search
顧名思義,實(shí)際上算法的目的是想要發(fā)現(xiàn)一個完整的網(wǎng)絡(luò)結(jié)構(gòu)。因此多會采用強(qiáng)化學(xué)習(xí)的方式?,F(xiàn)有的方法很多都是使用Q-learning的方法來學(xué)習(xí)的。那么會存在的問題是,需要搜索的網(wǎng)絡(luò)數(shù)量會呈指數(shù)級增長。最后導(dǎo)致的結(jié)果就是網(wǎng)絡(luò)會更淺。
Micro Search
這種不是搜索整個神經(jīng)網(wǎng)絡(luò),而是搜索neural cells的方式。找到指定的neural cells后,再去堆疊。這種設(shè)計方式雖然能夠設(shè)計更深的網(wǎng)絡(luò),但是依舊要消耗很長時間,比如100GPU days,超長。這篇文章就是在消耗上面做優(yōu)化。
算法原理
DAG的搜索空間
前面也說了DAG是通過搜索所謂的neural cell而不是搜索整個網(wǎng)絡(luò)。每個cell由多個節(jié)點(diǎn)和節(jié)點(diǎn)間的激活函數(shù)構(gòu)成。節(jié)點(diǎn)我們用來表示,節(jié)點(diǎn)的計算如下圖。每個節(jié)點(diǎn)有其余兩個節(jié)點(diǎn)(下面公式中的節(jié)點(diǎn)i和節(jié)點(diǎn)j)來生成,而中間會從一個函數(shù)集合中去sample函數(shù)出來, 這個F數(shù)據(jù)集的組成是1)恒等映射 2)歸零 3)3x3 depthwise分離卷積 4)3x3 dilated depthwise 分離卷積 5)5x5 depthwise分離卷積 6)5x5 dilated depthwise 分離卷積。7)3x3平均池化 8) 3 x 3 最大池化。
那么生成節(jié)點(diǎn)I后,再去生成對應(yīng)的cell。我們將cell的節(jié)點(diǎn)數(shù)記為B,以B=4為例,該cell實(shí)際上會包括7個節(jié)點(diǎn),是前面兩層的cell的輸出(實(shí)際上也就是上面公式中的k和j),而則是我們(1)中計算出來的結(jié)果。也就是該cell的output tensor實(shí)際上是四個節(jié)點(diǎn)的output的聯(lián)結(jié)。
將cell組裝為網(wǎng)絡(luò)
剛剛上面的這種叫做normal cell,作者還設(shè)計了一個reduction cell, 用于下采樣。這個reduction cell就是手動設(shè)計的了,沒有像normal cell那樣復(fù)雜。normal cell 的步長為1,reduction cell步長為2, 最后的網(wǎng)絡(luò)實(shí)際上就是由這些cell組裝起來的。如下圖:
搜索模型參數(shù)
搭建的工作如上面所示,好像也還好,就像搭積木,這篇論文我覺得創(chuàng)新的地方在于它的搜索方法,特別是通過梯度下降的方式來更新參數(shù),很棒。具體的搜索參數(shù)環(huán)節(jié),它是這么做的:
首先我們的優(yōu)化目標(biāo)和手工設(shè)計的網(wǎng)絡(luò)別無二致,都是最大釋然估計:
而上式中的Pr,實(shí)際上可寫成:
這個實(shí)際上是node i和node j的函數(shù)分布,k則是F的基數(shù)。而Node可以表示為:
是從中sample出來的,而
這個實(shí)際上是node i和node j的函數(shù)分布,k則是F的基數(shù)。而Node可以表示為:
其中是從離散分布中間sample出來的函數(shù)。這里問題來了,如果直接去優(yōu)化Pr,這里由于I是來自于一個離散分布,沒法對離散分布使用梯度下降方法。這里,作者使用了Gumbel-Max trick來解決離散分布中采樣不可微的問題,具體可以看這個問題下的回答
如何理解Gumbel-Max trick?
TL;DR: Gumbel Trick 是一種從離散分布取樣的方法,它的形式可以允許我們定義一種可微分的,離散分布的近似取樣,這種取樣方式不像「干脆以各類概率值的概率向量替代取樣」這么粗糙,也不像直接取樣一樣不可導(dǎo)(因此沒辦法應(yīng)對可能的 bp )。
于是這里將這個離散分布不可微的問題做了轉(zhuǎn)移,同時對應(yīng)的優(yōu)化目標(biāo)變?yōu)椋?/p>
這里有個的參數(shù),可以控制的相似程度。注意在前向傳播中我們使用的是等式(5), 而在反向傳播中,使用的是等式(7)。結(jié)合以上內(nèi)容,我們模型的loss是:
我們將最后學(xué)習(xí)到的網(wǎng)絡(luò)結(jié)構(gòu)稱為A,每一個節(jié)點(diǎn)由前面T個節(jié)點(diǎn)連接而來,在CNN中,我們把T設(shè)為2, 在RNN中,T設(shè)為1
在參數(shù)上,作者使用了SGD,學(xué)習(xí)率從0.025逐漸降到1e-3,使用的是cosine schedule。具體的參數(shù)和function F 設(shè)計上,可以去看看原論文。
總的來說,我覺得這篇論文最大的創(chuàng)新點(diǎn)是使用Gumbel-Max trick來使得搜索過程可微分,當(dāng)然它中間也使用了一些手動設(shè)計的模塊(如reduction cell),所以速度會比其余的NAS更快,之前我也沒有接觸過NAS, 看完這篇論文后對現(xiàn)在的NAS常用的方法以及未來NAS發(fā)展的趨勢還是有了更深的理解,推薦看看原文。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4819瀏覽量
106068 -
gpu
+關(guān)注
關(guān)注
28文章
5036瀏覽量
133739 -
強(qiáng)化學(xué)習(xí)
+關(guān)注
關(guān)注
4文章
269瀏覽量
11820
原文標(biāo)題:單v100 GPU,4小時搜索到一個魯棒的網(wǎng)絡(luò)結(jié)構(gòu)
文章出處:【微信號:rgznai100,微信公眾號:rgznai100】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
自動駕駛中常提的魯棒性是個啥?

DVB-H網(wǎng)絡(luò)結(jié)構(gòu)
特斯拉V100 Nvlink是否支持v100卡的nvlink變種的GPU直通?
神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索有什么優(yōu)勢?
備貨Hi3519A V100 4K智能IP攝像頭SoC使用手冊分享
TD-SCDMA R4網(wǎng)絡(luò)結(jié)構(gòu)和技術(shù)要求
環(huán)形網(wǎng)絡(luò),環(huán)形網(wǎng)絡(luò)結(jié)構(gòu)是什么?
4G網(wǎng)絡(luò)結(jié)構(gòu)及關(guān)鍵技術(shù)

魯棒性是什么意思_Robust為什么翻譯成魯棒性

一種改進(jìn)的深度神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索方法

基于YOLO-V5的網(wǎng)絡(luò)結(jié)構(gòu)及實(shí)現(xiàn)行人社交距離風(fēng)險提示
物聯(lián)網(wǎng)行業(yè)通用主板—卓越V100

評論