只要網(wǎng)絡(luò)足夠?qū)挘?a href="http://www.brongaenegriffin.com/v/tag/448/" target="_blank">深度學(xué)習(xí)動(dòng)態(tài)就能大大簡(jiǎn)化,并且更易于理解。
最近的許多研究結(jié)果表明,無(wú)限寬度的DNN會(huì)收斂成一類(lèi)更為簡(jiǎn)單的模型,稱(chēng)為高斯過(guò)程(Gaussian processes)。
于是,復(fù)雜的現(xiàn)象可以被歸結(jié)為簡(jiǎn)單的線性代數(shù)方程,以了解AI到底是怎樣工作的。
所謂的無(wú)限寬度(infinite width),指的是完全連接層中的隱藏單元數(shù),或卷積層中的通道數(shù)量有無(wú)窮多。
但是,問(wèn)題來(lái)了:推導(dǎo)有限網(wǎng)絡(luò)的無(wú)限寬度限制需要大量的數(shù)學(xué)知識(shí),并且必須針對(duì)不同研究的體系結(jié)構(gòu)分別進(jìn)行計(jì)算。對(duì)工程技術(shù)水平的要求也很高。
谷歌最新開(kāi)源的Neural Tangents,旨在解決這個(gè)問(wèn)題,讓研究人員能夠輕松建立、訓(xùn)練無(wú)限寬神經(jīng)網(wǎng)絡(luò)。
甚至只需要5行代碼,就能夠打造一個(gè)無(wú)限寬神經(jīng)網(wǎng)絡(luò)模型。
這一研究成果已經(jīng)中了ICLR 2020。戳進(jìn)文末Colab鏈接,即可在線試玩。
開(kāi)箱即用,5行代碼打造無(wú)限寬神經(jīng)網(wǎng)絡(luò)模型
Neural Tangents 是一個(gè)高級(jí)神經(jīng)網(wǎng)絡(luò) API,可用于指定復(fù)雜、分層的神經(jīng)網(wǎng)絡(luò),在 CPU/GPU/TPU 上開(kāi)箱即用。
該庫(kù)用 JAX編寫(xiě),既可以構(gòu)建有限寬度神經(jīng)網(wǎng)絡(luò),亦可輕松創(chuàng)建和訓(xùn)練無(wú)限寬度神經(jīng)網(wǎng)絡(luò)。
有什么用呢?舉個(gè)例子,你需要訓(xùn)練一個(gè)完全連接神經(jīng)網(wǎng)絡(luò)。通常,神經(jīng)網(wǎng)絡(luò)是隨機(jī)初始化的,然后采用梯度下降進(jìn)行訓(xùn)練。
研究人員通過(guò)對(duì)一組神經(jīng)網(wǎng)絡(luò)中不同成員的預(yù)測(cè)取均值,來(lái)提升模型的性能。另外,每個(gè)成員預(yù)測(cè)中的方差可以用來(lái)估計(jì)不確定性。
如此一來(lái),就需要大量的計(jì)算預(yù)算。
但當(dāng)神經(jīng)網(wǎng)絡(luò)變得無(wú)限寬時(shí),網(wǎng)絡(luò)集合就可以用高斯過(guò)程來(lái)描述,其均值和方差可以在整個(gè)訓(xùn)練過(guò)程中進(jìn)行計(jì)算。
而使用 Neural Tangents ,僅需5行代碼,就能完成對(duì)無(wú)限寬網(wǎng)絡(luò)集合的構(gòu)造和訓(xùn)練。
from neural_tangents import predict, staxinit_fn, apply_fn, kernel_fn = stax.serial( stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(), stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(), stax.Dense(1, W_std=1.5, b_std=0.05))y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, ‘ntk’, diag_reg=1e-4, compute_cov=True)
上圖中,左圖為訓(xùn)練過(guò)程中輸出(f)隨輸入數(shù)據(jù)(x)的變化;右圖為訓(xùn)練過(guò)程中的不確定性訓(xùn)練、測(cè)試損失。
將有限神經(jīng)網(wǎng)絡(luò)的集合訓(xùn)練和相同體系結(jié)構(gòu)的無(wú)限寬度神經(jīng)網(wǎng)絡(luò)集合進(jìn)行比較,研究人員發(fā)現(xiàn),使用無(wú)限寬模型的精確推理,與使用梯度下降訓(xùn)練整體模型的結(jié)果之間,具有良好的一致性。
這說(shuō)明了無(wú)限寬神經(jīng)網(wǎng)絡(luò)捕捉訓(xùn)練動(dòng)態(tài)的能力。
不僅如此,常規(guī)神經(jīng)網(wǎng)絡(luò)可以解決的問(wèn)題,Neural Tangents 構(gòu)建的網(wǎng)絡(luò)亦不在話下。
研究人員在 CIFAR-10 數(shù)據(jù)集的圖像識(shí)別任務(wù)上比較了 3 種不同架構(gòu)的無(wú)限寬神經(jīng)網(wǎng)絡(luò)。
可以看到,無(wú)限寬網(wǎng)絡(luò)模擬有限神經(jīng)網(wǎng)絡(luò),遵循相似的性能層次結(jié)構(gòu),其全連接網(wǎng)絡(luò)的性能比卷積網(wǎng)絡(luò)差,而卷積網(wǎng)絡(luò)的性能又比寬殘余網(wǎng)絡(luò)差。
但是,與常規(guī)訓(xùn)練不同,這些模型的學(xué)習(xí)動(dòng)力在封閉形式下是易于控制的,也就是說(shuō),可以用前所未有的視角去觀察其行為。
對(duì)于深入理解機(jī)器學(xué)習(xí)機(jī)制來(lái)說(shuō),該研究也提供了一種新思路。谷歌表示,這將有助于“打開(kāi)機(jī)器學(xué)習(xí)的黑匣子”。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4814瀏覽量
103578 -
代碼
+關(guān)注
關(guān)注
30文章
4900瀏覽量
70718 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5561瀏覽量
122793
發(fā)布評(píng)論請(qǐng)先 登錄
評(píng)論