可逆神經(jīng)網(wǎng)絡(luò)(Invertible Neural Networks)詳細(xì)解析:讓神經(jīng)網(wǎng)絡(luò)更加輕量化
來(lái)源:PaperWeekly
為什么要用可逆網(wǎng)絡(luò)呢?
- 因?yàn)榫幋a和解碼使用相同的參數(shù),所以 model 是輕量級(jí)的??赡娴慕翟刖W(wǎng)絡(luò) InvDN 只有 DANet 網(wǎng)絡(luò)參數(shù)量的 4.2%,但是 InvDN 的降噪性能更好。
- 由于可逆網(wǎng)絡(luò)是信息無(wú)損的,所以它能保留輸入數(shù)據(jù)的細(xì)節(jié)信息。
- 無(wú)論網(wǎng)絡(luò)的深度如何,可逆網(wǎng)絡(luò)都使用恒定的內(nèi)存來(lái)計(jì)算梯度。
其中最主要目的就是為了減少內(nèi)存的消耗,當(dāng)前所有的神經(jīng)網(wǎng)絡(luò)都采用反向傳播的方式來(lái)訓(xùn)練,反向傳播算法需要存儲(chǔ)網(wǎng)絡(luò)的中間結(jié)果來(lái)計(jì)算梯度,而且其對(duì)內(nèi)存的消耗與網(wǎng)絡(luò)單元數(shù)成正比。這也就意味著,網(wǎng)絡(luò)越深越廣,對(duì)內(nèi)存的消耗越大,這將成為很多應(yīng)用的瓶頸。
下面是 Pytorch summary 的結(jié)果,F(xiàn)orward/backward pass size(MB): 218.59 就是需要保存的中間變量大小,可以看出這部分占據(jù)了很大部分顯存(隨著網(wǎng)絡(luò)深度的增加,中間變量占據(jù)顯存量會(huì)一直增加,resnet152(size=224)的中間變量更是占據(jù)總共內(nèi)存的 606.6÷836.79≈0.725 )。如果不存儲(chǔ)中間層結(jié)果,那么就可以大幅減少 GPU 的顯存占用,有助于訓(xùn)練更深更廣的網(wǎng)絡(luò)。
import torchfrom torchvision import modelsfrom torchsummary import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')vgg = models.vgg16().to(device)
summary(vgg, (3, 224, 224))
結(jié)果:---------------------------------------------------------------- Layer (type) Output Shape Param #================================================================ Conv2d-1 [-1, 64, 224, 224] 1,792 ReLU-2 [-1, 64, 224, 224] 0 Conv2d-3 [-1, 64, 224, 224] 36,928 ReLU-4 [-1, 64, 224, 224] 0 MaxPool2d-5 [-1, 64, 112, 112] 0 Conv2d-6 [-1, 128, 112, 112] 73,856 ReLU-7 [-1, 128, 112, 112] 0 Conv2d-8 [-1, 128, 112, 112] 147,584 ReLU-9 [-1, 128, 112, 112] 0 MaxPool2d-10 [-1, 128, 56, 56] 0 Conv2d-11 [-1, 256, 56, 56] 295,168 ReLU-12 [-1, 256, 56, 56] 0 Conv2d-13 [-1, 256, 56, 56] 590,080 ReLU-14 [-1, 256, 56, 56] 0 Conv2d-15 [-1, 256, 56, 56] 590,080 ReLU-16 [-1, 256, 56, 56] 0 MaxPool2d-17 [-1, 256, 28, 28] 0 Conv2d-18 [-1, 512, 28, 28] 1,180,160 ReLU-19 [-1, 512, 28, 28] 0 Conv2d-20 [-1, 512, 28, 28] 2,359,808 ReLU-21 [-1, 512, 28, 28] 0 Conv2d-22 [-1, 512, 28, 28] 2,359,808 ReLU-23 [-1, 512, 28, 28] 0 MaxPool2d-24 [-1, 512, 14, 14] 0 Conv2d-25 [-1, 512, 14, 14] 2,359,808 ReLU-26 [-1, 512, 14, 14] 0 Conv2d-27 [-1, 512, 14, 14] 2,359,808 ReLU-28 [-1, 512, 14, 14] 0 Conv2d-29 [-1, 512, 14, 14] 2,359,808 ReLU-30 [-1, 512, 14, 14] 0 MaxPool2d-31 [-1, 512, 7, 7] 0 Linear-32 [-1, 4096] 102,764,544 ReLU-33 [-1, 4096] 0 Dropout-34 [-1, 4096] 0 Linear-35 [-1, 4096] 16,781,312 ReLU-36 [-1, 4096] 0 Dropout-37 [-1, 4096] 0 Linear-38 [-1, 1000] 4,097,000================================================================Total params: 138,357,544Trainable params: 138,357,544Non-trainable params: 0----------------------------------------------------------------Input size (MB): 0.57Forward/backward pass size (MB): 218.59Params size (MB): 527.79Estimated Total Size (MB): 746.96----------------------------------------------------------------
接下來(lái)我將先從可逆神經(jīng)網(wǎng)絡(luò)講起,然后是神經(jīng)網(wǎng)絡(luò)的反向傳播,最后是標(biāo)準(zhǔn)殘差網(wǎng)絡(luò)。對(duì)反向傳播算法和標(biāo)準(zhǔn)殘差網(wǎng)絡(luò)比較熟悉的小伙伴,可以只看第一節(jié):可逆神經(jīng)網(wǎng)絡(luò)。如果各位小伙伴不熟悉反向傳播算法和標(biāo)準(zhǔn)殘差網(wǎng)絡(luò),建議先看第二節(jié):反向傳播(BP)算法和第三節(jié):殘差網(wǎng)絡(luò)(Residual Network)。本文1.2和1.3.4摘錄自 @阿亮。
可逆神經(jīng)網(wǎng)絡(luò)
可逆網(wǎng)絡(luò)具有的性質(zhì):
網(wǎng)絡(luò)的輸入、輸出的大小必須一致。
網(wǎng)絡(luò)的雅可比行列式不為 0。
1.1 什么是雅可比行列式?
雅可比行列式通常稱為雅可比式(Jacobian),它是以 n 個(gè) n 元函數(shù)的偏導(dǎo)數(shù)為元素的行列式 。事實(shí)上,在函數(shù)都連續(xù)可微(即偏導(dǎo)數(shù)都連續(xù))的前提之下,它就是函數(shù)組的微分形式下的系數(shù)矩陣(即雅可比矩陣)的行列式。若因變量對(duì)自變量連續(xù)可微,而自變量對(duì)新變量連續(xù)可微,則因變量也對(duì)新變量連續(xù)可微。這可用行列式的乘法法則和偏導(dǎo)數(shù)的連鎖法則直接驗(yàn)證。也類似于導(dǎo)數(shù)的連鎖法則。偏導(dǎo)數(shù)的連鎖法則也有類似的公式;這常用于重積分的計(jì)算中。
為什么神經(jīng)網(wǎng)絡(luò)會(huì)與雅可比行列式有關(guān)系?這里我借用李宏毅老師的 ppt(12-14頁(yè))。想看視頻的可以到 b 站上看。
簡(jiǎn)單的來(lái)講就是 ,他們的分布之間的關(guān)系就變?yōu)?nbsp;,又因?yàn)橛?nbsp;,所以 這個(gè)網(wǎng)絡(luò)的雅可比行列式不為 0 才行。
順便提一下,flow-based Model 優(yōu)化的損失函數(shù)如下:
其實(shí)這里跟矩陣運(yùn)算很像,矩陣可逆的條件也是矩陣的雅可比行列式不為 0,雅可比矩陣可以理解為矩陣的一階導(dǎo)數(shù)。
假設(shè)可逆網(wǎng)絡(luò)的表達(dá)式為:
它的雅可比矩陣為:
其行列式為 1。
https://arxiv.org/abs/1707.04585
1.3.1 可逆塊結(jié)構(gòu)
可逆神經(jīng)網(wǎng)絡(luò)將每一層分割成兩部分,分別為 和 ,每一個(gè)可逆塊的輸入是 ,輸出是 。其結(jié)構(gòu)如下:
正向計(jì)算圖示:
公式表示:
逆向計(jì)算圖示:
公式表示:
其中 F 和 G 都是相似的殘差函數(shù),參考上圖殘差網(wǎng)絡(luò)。可逆塊的跨距只能為 1,也就是說(shuō)可逆塊必須一個(gè)接一個(gè)連接,中間不能采用其它網(wǎng)絡(luò)形式銜接,否則的話就會(huì)丟失信息,并且無(wú)法可逆計(jì)算了,這點(diǎn)與殘差塊不一樣。如果一定要采取跟殘差塊相似的結(jié)構(gòu),也就是中間一部分采用普通網(wǎng)絡(luò)形式銜接,那中間這部分的激活結(jié)果就必須顯式的存起來(lái)。
1.3.2 不用存儲(chǔ)激活結(jié)果的反向傳播
為了更好地計(jì)算反向傳播的步驟,我們修改一下上述正向計(jì)算和逆向計(jì)算的公式:
盡管 和 的值是相同的,但是兩個(gè)變量在圖中卻代表不同的節(jié)點(diǎn),所以在反向傳播中它們的總體導(dǎo)數(shù)是不一樣的。 的導(dǎo)數(shù)包含通過(guò) 產(chǎn)生的間接影響,而 的導(dǎo)數(shù)卻不受 的任何影響。
在反向傳播計(jì)算流程中,先給出最后一層的激活值 和誤差傳播的總體導(dǎo)數(shù) ,然后要計(jì)算出其輸入值 和對(duì)應(yīng)的導(dǎo)數(shù) ,以及殘差函數(shù) F 和 G 中權(quán)重參數(shù)的總體導(dǎo)數(shù),求解步驟如下:
1.3.3 計(jì)算開(kāi)銷
一個(gè) N 個(gè)連接的神經(jīng)網(wǎng)絡(luò),正向計(jì)算的理論加乘開(kāi)銷為 N,反向傳播求導(dǎo)的理論加乘開(kāi)銷為 2N(反向求導(dǎo)包含復(fù)合函數(shù)求導(dǎo)連乘),而可逆網(wǎng)絡(luò)多一步需要反向計(jì)算輸入值的操作,所以理論計(jì)算開(kāi)銷為 4N,比普通網(wǎng)絡(luò)開(kāi)銷約多出 33% 左右。但是在實(shí)際操作中,正向和反向的計(jì)算開(kāi)銷在 GPU 上差不多,可以都理解為 N。那么這樣的話,普通網(wǎng)絡(luò)的整體計(jì)算開(kāi)銷為 2N,可逆網(wǎng)絡(luò)的整體開(kāi)銷為 3N,也就是多出了約 50%。
1.3.4 雅可比行列式的計(jì)算
其編碼公式如下:
其解碼公式如下:
為了計(jì)算雅可比矩陣,我們更直觀的寫(xiě)成下面的編碼公式:
它的雅可比矩陣為:
其實(shí)上面這個(gè)雅可比行列式也是 1,因?yàn)檫@里 ,它們的系數(shù)是一樣的。
有另外一種解釋方式就是把這種對(duì)偶的形式切成兩半:
其行列式為 1。
因?yàn)槭菍?duì)偶的形式,所以這里的行列式也為 1。
因?yàn)?nbsp;,所以其行列式也為 1。
上圖中符號(hào)的含義:
- x1,x2,x3:表示 3 個(gè)輸入層節(jié)點(diǎn)。
- :表示從 t-1 層到 t 層的權(quán)重參數(shù),j 表示 t 層的第 j 個(gè)節(jié)點(diǎn),i 表示 t-1 層的第 i 個(gè)節(jié)點(diǎn)。
- :表示 t 層的第 i 個(gè)激活后輸出結(jié)果。
- g(x):表示激活函數(shù)。
正向傳播計(jì)算過(guò)程:
隱藏層(網(wǎng)絡(luò)的第二層)
輸出層(網(wǎng)絡(luò)的最后一層)
反向傳播計(jì)算過(guò)程:
以單個(gè)樣本為例,假設(shè)輸入向量是 [x1,x2,x3],目標(biāo)輸出值是 [y1,y2],代價(jià)函數(shù)用 L 表示。反向傳播的總體原理就是根據(jù)總體輸出誤差,反向傳播回網(wǎng)絡(luò),通過(guò)計(jì)算每一層節(jié)點(diǎn)的梯度,利用梯度下降法原理,更新每一層的網(wǎng)絡(luò)權(quán)重 w 和偏置 b,這也是網(wǎng)絡(luò)學(xué)習(xí)的過(guò)程。誤差反向傳播的優(yōu)點(diǎn)就是可以把繁雜的導(dǎo)數(shù)計(jì)算以數(shù)列遞推的形式來(lái)表示, 簡(jiǎn)化了計(jì)算過(guò)程。
以平方誤差來(lái)計(jì)算反向傳播的過(guò)程,代價(jià)函數(shù)表示如下:
引入新的誤差求導(dǎo)表示形式,稱為神經(jīng)單元誤差:
l=2,3 表示第幾層,j 表示某一層的第幾個(gè)節(jié)點(diǎn)。替換表示后如下:
所以我們可以歸納出一般的計(jì)算公式:
從上述公式可以看出,如果神經(jīng)單元誤差 δ 可以求出來(lái),那么總誤差對(duì)每一層的權(quán)重 w 和偏置 b 的偏導(dǎo)數(shù)就可以求出來(lái),接下來(lái)就可以利用梯度下降法來(lái)優(yōu)化參數(shù)了。
求解每一層的 δ:
輸出層
隱藏層
從而得出 l 層神經(jīng)單元誤差和 l+1 層神經(jīng)單元誤差的關(guān)系。這就是誤差反向傳播算法,只要求出輸出層的神經(jīng)單元誤差,其它層的神經(jīng)單元誤差就不需要計(jì)算偏導(dǎo)數(shù)了,而可以直接通過(guò)上述公式得出。
殘差網(wǎng)絡(luò)(Residual Network)
梯度消失問(wèn)題;
網(wǎng)絡(luò)退化問(wèn)題。
所以在第二層進(jìn)入激活函數(shù) ReLU之 前 F(x)+x 組成新的輸入,也叫恒等映射。
恒等映射就是在這個(gè)殘差塊輸入是 x 的情況下輸出依然是 x,這樣其目標(biāo)就是學(xué)習(xí)讓 F(X)=0。
這里有一個(gè)問(wèn)題哈,為什么要額外加一個(gè) x 呢,而不是讓模型直接學(xué)習(xí) F(x)=x?
因?yàn)樽?F(x)=0 比較容易,初始化參數(shù) W 非常小接近 0,就可以讓輸出接近 0,同時(shí)輸出如果是負(fù)數(shù),經(jīng)過(guò)第一層 Relu 后輸出依然 0,都能使得最后的 F(x)=0,也就是有多種情況都可以使得 F(x)=0;但是讓 F(x)=x 確實(shí)非常難的,因?yàn)閰?shù)都必須剛剛好才能使得最后輸出為 x。
恒等映射有什么作用?
恒等映射就可以解決網(wǎng)絡(luò)退化的問(wèn)題,當(dāng)網(wǎng)絡(luò)層數(shù)越來(lái)越深的時(shí)候,網(wǎng)絡(luò)的精度卻在下降,也就是說(shuō)網(wǎng)絡(luò)自身存在一個(gè)最優(yōu)的層度結(jié)構(gòu),太深太淺都能使得模型精度下降。有了恒等映射存在,網(wǎng)絡(luò)就能夠自己學(xué)習(xí)到哪些層是冗余的,就可以無(wú)損通過(guò)這些層,理論上講再深的網(wǎng)絡(luò)都不影響其精度,解決了網(wǎng)絡(luò)退化問(wèn)題。
為什么可以解決梯度消失問(wèn)題呢?
以兩個(gè)殘差塊的結(jié)構(gòu)實(shí)例圖來(lái)分析,其中每個(gè)殘差塊有 2 層神經(jīng)網(wǎng)絡(luò)組成,如下圖:
假設(shè)激活函數(shù) ReLU 用 g(x) 函數(shù)來(lái)表示,樣本實(shí)例是 [x1,y1],即輸入是 x1,目標(biāo)值是 y1,損失函數(shù)還是采用平方損失函數(shù),則每一層的計(jì)算如下:
下面我們對(duì)第一個(gè)殘差塊的權(quán)重參數(shù)求導(dǎo),根據(jù)鏈?zhǔn)角髮?dǎo)法則,公式如下:
我們可以看到求導(dǎo)公式中多了一個(gè)+1項(xiàng),這就將原來(lái)的鏈?zhǔn)角髮?dǎo)中的連乘變成了連加狀態(tài),可以有效避免梯度消失了。
參考文獻(xiàn):
[1] PPT https://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/FLOW%20(v7).pdf[2] 神經(jīng)網(wǎng)絡(luò)的可逆形式 https://zhuanlan.zhihu.com/p/268242678[3] 大幅減少GPU顯存占用:可逆殘差網(wǎng)絡(luò)(The Reversible Residual Network) https://www.cnblogs.com/gczr/p/12181354.html[4] 雅可比行列式 https://baike.baidu.com/item/雅可比行列式/4709261?fr=aladdin[5] The Reversible Residual Network: Backpropagation Without Storing Activations[6] pytorch-summary https://github.com/sksq96/pytorch-summary
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。
pwm相關(guān)文章:pwm是什么