Batch Normalization原理與實戰(zhàn)(1)
作者丨天雨粟@知乎
來源丨h(huán)ttps://zhuanlan.zhihu.com/p/34879333
編輯丨江大白
導讀本文主要從理論與實戰(zhàn)的視角,對深度學習中的Batch Normalization的思路進行講解、歸納和總結(jié),并輔以代碼讓小伙伴兒們對Batch Normalization的作用有更加直觀的了解。
前言本文主要分為兩大部分。第一部分是理論板塊,主要從背景、算法、效果等角度對Batch Normalization進行詳解;第二部分是實戰(zhàn)板塊,主要以MNIST數(shù)據(jù)集作為整個代碼測試的數(shù)據(jù),通過比較加入Batch Normalization前后網(wǎng)絡(luò)的性能來讓大家對Batch Normalization的作用與效果有更加直觀的感知。
(一)理論板塊理論板塊將從以下四個方面對Batch Normalization進行詳解:
- 提出背景
- BN算法思想
- 測試階段如何使用BN
- BN的優(yōu)勢
理論部分主要參考2015年Google的Sergey Ioffe與Christian Szegedy的論文內(nèi)容,并輔以吳恩達Coursera課程與其它博主的資料。所有參考內(nèi)容鏈接均見于文章最后參考鏈接部分。
1 提出背景1.1 煉丹的困擾在深度學習中,由于問題的復雜性,我們往往會使用較深層數(shù)的網(wǎng)絡(luò)進行訓練,相信很多煉丹的朋友都對調(diào)參的困難有所體會,尤其是對深層神經(jīng)網(wǎng)絡(luò)的訓練調(diào)參更是困難且復雜。在這個過程中,我們需要去嘗試不同的學習率、初始化參數(shù)方法(例如Xavier初始化)等方式來幫助我們的模型加速收斂。深度神經(jīng)網(wǎng)絡(luò)之所以如此難訓練,其中一個重要原因就是網(wǎng)絡(luò)中層與層之間存在高度的關(guān)聯(lián)性與耦合性。下圖是一個多層的神經(jīng)網(wǎng)絡(luò),層與層之間采用全連接的方式進行連接。
我們規(guī)定左側(cè)為神經(jīng)網(wǎng)絡(luò)的底層,右側(cè)為神經(jīng)網(wǎng)絡(luò)的上層。那么網(wǎng)絡(luò)中層與層之間的關(guān)聯(lián)性會導致如下的狀況:隨著訓練的進行,網(wǎng)絡(luò)中的參數(shù)也隨著梯度下降在不停更新。一方面,當?shù)讓泳W(wǎng)絡(luò)中參數(shù)發(fā)生微弱變化時,由于每一層中的線性變換與非線性激活映射,這些微弱變化隨著網(wǎng)絡(luò)層數(shù)的加深而被放大(類似蝴蝶效應);另一方面,參數(shù)的變化導致每一層的輸入分布會發(fā)生改變,進而上層的網(wǎng)絡(luò)需要不停地去適應這些分布變化,使得我們的模型訓練變得困難。上述這一現(xiàn)象叫做Internal Covariate Shift。
1.2 什么是Internal Covariate ShiftBatch Normalization的原論文作者給了Internal Covariate Shift一個較規(guī)范的定義:在深層網(wǎng)絡(luò)訓練的過程中,由于網(wǎng)絡(luò)中參數(shù)變化而引起內(nèi)部結(jié)點數(shù)據(jù)分布發(fā)生變化的這一過程被稱作Internal Covariate Shift。
這句話該怎么理解呢? 我們同樣以1.1中的圖為例, 我們定義每一層的線性變換為 input , 其中 代表層數(shù); 非線性變換為 , 其中 為 第 層的激活函數(shù)。
隨著梯度下降的進行, 每一層的參數(shù) 與 都會被更新, 那么 的分布也就發(fā)生了改變, 進而 也同樣出現(xiàn)分布的改變。而 作為第 層的輸入, 意味著 層就需要去不停 適應這種數(shù)據(jù)分布的變化, 這一過程就被叫做Internal Covariate Shift。
1.3 Internal Covariate Shift會帶來什么問題?(1)上層網(wǎng)絡(luò)需要不停調(diào)整來適應輸入數(shù)據(jù)分布的變化,導致網(wǎng)絡(luò)學習速度的降低
我們在上面提到了梯度下降的過程會讓每一層的參數(shù) 和 發(fā)生變化,進而使得每一層的線性與非線性計算結(jié)果分布產(chǎn)生變化。后層網(wǎng)絡(luò)就要不停地去適應這種分布變化,這個時候就會使得整個網(wǎng)絡(luò)的學習速率過慢。
(2)網(wǎng)絡(luò)的訓練過程容易陷入梯度飽和區(qū),減緩網(wǎng)絡(luò)收斂速度
當我們在神經(jīng)網(wǎng)絡(luò)中采用飽和激活函數(shù) (saturated activation function) 時, 例如sigmoid, tanh 激活函數(shù), 很容易使得模型訓練陷入梯度飽和區(qū) (saturated regime)。隨著模型訓練的進行, 我 們的參數(shù) 會逐漸更新并變大, 此時 就會隨之變大, 并且 還受 到更底層網(wǎng)絡(luò)參數(shù) 的影響, 隨著網(wǎng)絡(luò)層數(shù)的加深, 很容易陷入梯度 飽和區(qū), 此時梯度會變得很小甚至接近于 0 , 參數(shù)的更新速度就會減慢, 進而就會放慢網(wǎng)絡(luò)的收玫 速度。
對于激活函數(shù)梯度飽和問題,有兩種解決思路。第一種就是更為非飽和性激活函數(shù),例如線性整流函數(shù)ReLU可以在一定程度上解決訓練進入梯度飽和區(qū)的問題。另一種思路是,我們可以讓激活函數(shù)的輸入分布保持在一個穩(wěn)定狀態(tài)來盡可能避免它們陷入梯度飽和區(qū),這也就是Normalization的思路。
1.4 我們?nèi)绾螠p緩Internal Covariate Shift?要緩解ICS的問題,就要明白它產(chǎn)生的原因。ICS產(chǎn)生的原因是由于參數(shù)更新帶來的網(wǎng)絡(luò)中每一層輸入值分布的改變,并且隨著網(wǎng)絡(luò)層數(shù)的加深而變得更加嚴重,因此我們可以通過固定每一層網(wǎng)絡(luò)輸入值的分布來對減緩ICS問題。
(1)白化(Whitening)
白化(Whitening)是機器學習里面常用的一種規(guī)范化數(shù)據(jù)分布的方法,主要是PCA白化與ZCA白化。白化是對輸入數(shù)據(jù)分布進行變換,進而達到以下兩個目的:
- 使得輸入特征分布具有相同的均值與方差。 其中PCA白化保證了所有特征分布均值為0,方差為1;而ZCA白化則保證了所有特征分布均值為0,方差相同;
- 去除特征之間的相關(guān)性。
通過白化操作,我們可以減緩ICS的問題,進而固定了每一層網(wǎng)絡(luò)輸入分布,加速網(wǎng)絡(luò)訓練過程的收斂(LeCun et al.,1998b;Wiesler&Ney,2011)。
(2)Batch Normalization提出
既然白化可以解決這個問題,為什么我們還要提出別的解決辦法?當然是現(xiàn)有的方法具有一定的缺陷,白化主要有以下兩個問題:
- 白化過程計算成本太高, 并且在每一輪訓練中的每一層我們都需要做如此高成本計算的白化操作;
- 白化過程由于改變了網(wǎng)絡(luò)每一層的分布,因而改變了網(wǎng)絡(luò)層中本身數(shù)據(jù)的表達能力。底層網(wǎng)絡(luò)學習到的參數(shù)信息會被白化操作丟失掉。
既然有了上面兩個問題,那我們的解決思路就很簡單,一方面,我們提出的normalization方法要能夠簡化計算過程;另一方面又需要經(jīng)過規(guī)范化處理后讓數(shù)據(jù)盡可能保留原始的表達能力。于是就有了簡化+改進版的白化——Batch Normalization。
2 Batch Normalization2.1 思路既然白化計算過程比較復雜,那我們就簡化一點,比如我們可以嘗試單獨對每個特征進行normalizaiton就可以了,讓每個特征都有均值為0,方差為1的分布就OK。
另一個問題,既然白化操作減弱了網(wǎng)絡(luò)中每一層輸入數(shù)據(jù)表達能力,那我就再加個線性變換操作,讓這些數(shù)據(jù)再能夠盡可能恢復本身的表達能力就好了。
因此,基于上面兩個解決問題的思路,作者提出了Batch Normalization,下一部分來具體講解這個算法步驟。
2.2 算法在深度學習中,由于采用full batch的訓練方式對內(nèi)存要求較大,且每一輪訓練時間過長;我們一般都會采用對數(shù)據(jù)做劃分,用mini-batch對網(wǎng)絡(luò)進行訓練。因此,Batch Normalization也就在mini-batch的基礎(chǔ)上進行計算。
2.2.1 參數(shù)定義我們依舊以下圖這個神經(jīng)網(wǎng)絡(luò)為例。我們定義網(wǎng)絡(luò)總共有 LL 層(不包含輸入層)并定義如下符號:
參數(shù)相關(guān):
- : 網(wǎng)絡(luò)中的層標號
- : 網(wǎng)絡(luò)中的最后一層或總層數(shù)
- : 第 層的維度, 即神經(jīng)元結(jié)點數(shù)
- : 第 層的權(quán)重矩陣,
- 第 層的偏置向量,
- : 第 層的線性計算結(jié)果, input
- : 第 層的激活函數(shù)
- : 第 層的非線性激活結(jié)果,
樣本相關(guān):
- : 訓練樣本的數(shù)量
- : 訓練樣本的特征數(shù)
- : 訓練樣本集, (注意這里 的一列是一個 樣本)
- : batch size, 即每個batch中樣本的數(shù)量
- : 第 個mini-batch的訓練數(shù)據(jù), , 其中
介紹算法思路沿襲前面BN提出的思路來講。第一點, 對每個特征進行獨立的normalization。我們考慮一個batch的訓練, 傳入 個訓練樣本, 并關(guān)注網(wǎng)絡(luò)中的某一層, 忽略上標 。
我們關(guān)注當前層的第 個維度, 也就是第 個神經(jīng)元結(jié)點, 則有 。我們當前維度進行規(guī)范化:
其中 是為了防止方差為0產(chǎn)生無效計算。
下面我們再來結(jié)合個具體的例子來進行計算。下圖我們只關(guān)注第 層的計算結(jié)果, 左邊的矩陣是 線性計算結(jié)果, 還末進行激活函數(shù)的非線性變換。此時每一列是一個樣本, 圖中可以看到共有8列, 代表當前訓練樣本的batch中共有8個樣本, 每一行代表當前 層神經(jīng)元的一個節(jié)點, 可以看到當前 層共有4個神經(jīng)元結(jié)點, 即第 層維度為4。我們可以看到, 每行的數(shù)據(jù)分布都不同。
對于第一個神經(jīng)元, 我們求得 (其中 ), 此時我們利用 對第一行數(shù)據(jù)(第一個維度)進行normalization得到新的值 。同理我們可以計算出其他輸入維度歸一化后的值。如下圖:
通過上面的變換,我們解決了第一個問題,即用更加簡化的方式來對數(shù)據(jù)進行規(guī)范化,使得第 ll 層的輸入每個特征的分布均值為0,方差為1。
如同上面提到的,Normalization操作我們雖然緩解了ICS問題,讓每一層網(wǎng)絡(luò)的輸入數(shù)據(jù)分布都變得穩(wěn)定,但卻導致了數(shù)據(jù)表達能力的缺失。也就是我們通過變換操作改變了原有數(shù)據(jù)的信息表達(representation ability of the network),使得底層網(wǎng)絡(luò)學習到的參數(shù)信息丟失。另一方面,通過讓每一層的輸入分布均值為0,方差為1,會使得輸入在經(jīng)過sigmoid或tanh激活函數(shù)時,容易陷入非線性激活函數(shù)的線性區(qū)域。
因此, BN又引入了兩個可學習 (learnable) 的參數(shù) 與 。這兩個參數(shù)的引入是為了恢復數(shù)據(jù)本 身的表達能力, 對規(guī)范化后的數(shù)據(jù)進行線性變換, 即 。特別地, 當 時, 可以實現(xiàn)等價變換(identity transform)并且保留了原始輸入特征的分布信 息。
通過上面的步驟,我們就在一定程度上保證了輸入數(shù)據(jù)的表達能力。
以上就是整個Batch Normalization在模型訓練中的算法和思路。
2.2.3 公式補充:在進行normalization的過程中, 由于我們的規(guī)范化操作會對減去均值, 因此, 偏置項 可以被忽略掉或可以被置為0, 即
對于神經(jīng)網(wǎng)絡(luò)中的第 ll 層,我們有:
我們知道BN在每一層計算的 與 都是基于當前batch中的訓練數(shù)據(jù), 但是這就帶來了一個問 題: 我們在預測階段, 有可能只需要預測一個樣本或很少的樣本, 沒有像訓練樣本中那么多的數(shù) 據(jù), 此時 與 的計算一定是有偏估計, 這個時候我們該如何進行計算呢?
利用BN訓練好模型后, 我們保留了每組mini-batch訓練數(shù)據(jù)在網(wǎng)絡(luò)中每一層的 與 。此時我們使用整個樣本的統(tǒng)計量來對Test數(shù)據(jù)進行歸一化,具體來說使用均值與方差的無偏估計:
得到每個特征的均值與方差的無偏估計后, 我們對test數(shù)據(jù)采用同樣的normalization方法:
另外,除了采用整體樣本的無偏估計外。吳恩達在Coursera上的Deep Learning課程指出可以對train階段每個batch計算的mean/variance采用指數(shù)加權(quán)平均來得到test階段mean/variance的估計。
4 Batch Normalization的優(yōu)勢Batch Normalization在實際工程中被證明了能夠緩解神經(jīng)網(wǎng)絡(luò)難以訓練的問題,BN具有的有事可以總結(jié)為以下三點:
(1)BN使得網(wǎng)絡(luò)中每層輸入數(shù)據(jù)的分布相對穩(wěn)定,加速模型學習速度
BN通過規(guī)范化與線性變換使得每一層網(wǎng)絡(luò)的輸入數(shù)據(jù)的均值與方差都在一定范圍內(nèi),使得后一層網(wǎng)絡(luò)不必不斷去適應底層網(wǎng)絡(luò)中輸入的變化,從而實現(xiàn)了網(wǎng)絡(luò)中層與層之間的解耦,允許每一層進行獨立學習,有利于提高整個神經(jīng)網(wǎng)絡(luò)的學習速度。
(2)BN使得模型對網(wǎng)絡(luò)中的參數(shù)不那么敏感,簡化調(diào)參過程,使得網(wǎng)絡(luò)學習更加穩(wěn)定
在神經(jīng)網(wǎng)絡(luò)中,我們經(jīng)常會謹慎地采用一些權(quán)重初始化方法(例如Xavier)或者合適的學習率來保證網(wǎng)絡(luò)穩(wěn)定訓練。
當學習率設(shè)置太高時, 會使得參數(shù)更新步伐過大, 容易出現(xiàn)震蕩和不收斂。但是使用BN的網(wǎng)絡(luò)將 不會受到參數(shù)數(shù)值大小的影響。例如, 我們對參數(shù) 進行縮放得到 。對于縮放前的值 , 我們設(shè)其均值為 , 方差為 ; 對于縮放值 , 設(shè)其均值為 , 方差為 , 則我們有:
我們忽略 , 則有:
注:公式中的 是當前層的輸入,也是前一層的輸出;不是下標啊旁友們!
我們可以看到, 經(jīng)過BN操作以后, 權(quán)重的縮放值會被“抺去”, 因此保證了輸入數(shù)據(jù)分布穩(wěn)定在一定范圍內(nèi)。另外, 權(quán)重的縮放并不會影響到對 的梯度計算; 并且當權(quán)重越大時, 即 越大, 越小,意味著權(quán)重 的梯度反而越小,這樣BN就保證了梯度不會依賴于參數(shù)的scale, 使得參數(shù)的更新處在更加穩(wěn)定的狀態(tài)。
因此,在使用Batch Normalization之后,抑制了參數(shù)微小變化隨著網(wǎng)絡(luò)層數(shù)加深被放大的問題,使得網(wǎng)絡(luò)對參數(shù)大小的適應能力更強,此時我們可以設(shè)置較大的學習率而不用過于擔心模型divergence的風險。
(3)BN允許網(wǎng)絡(luò)使用飽和性激活函數(shù)(例如sigmoid,tanh等),緩解梯度消失問題
在不使用BN層的時候, 由于網(wǎng)絡(luò)的深度與復雜性, 很容易使得底層網(wǎng)絡(luò)變化累積到上層網(wǎng)絡(luò)中, 導致模型的訓練很容易進入到激活函數(shù)的梯度飽和區(qū); 通過normalize操作可以讓激活函數(shù)的輸入數(shù)據(jù)落在梯度非飽和區(qū), 緩解梯度消失的問題; 另外通過自適應學習 與 又讓數(shù)據(jù)保留更多的原始信息。
(4)BN具有一定的正則化效果
在Batch Normalization中,由于我們使用mini-batch的均值與方差作為對整體訓練樣本均值與方差的估計,盡管每一個batch中的數(shù)據(jù)都是從總體樣本中抽樣得到,但不同mini-batch的均值與方差會有所不同,這就為網(wǎng)絡(luò)的學習過程中增加了隨機噪音,與Dropout通過關(guān)閉神經(jīng)元給網(wǎng)絡(luò)訓練帶來噪音類似,在一定程度上對模型起到了正則化的效果。
另外,原作者通過也證明了網(wǎng)絡(luò)加入BN后,可以丟棄Dropout,模型也同樣具有很好的泛化效果。
(二)實戰(zhàn)板塊
經(jīng)過了上面了理論學習,我們對BN有了理論上的認知?!癟alk is cheap, show me the code”。接下來我們就通過實際的代碼來對比加入BN前后的模型效果。實戰(zhàn)部分使用MNIST數(shù)據(jù)集作為數(shù)據(jù)基礎(chǔ),并使用TensorFlow中的Batch Normalization結(jié)構(gòu)來進行BN的實現(xiàn)。
數(shù)據(jù)準備:MNIST手寫數(shù)據(jù)集
代碼地址:我的GitHub (https://github.com/NELSONZHAO/zhihu/tree/master/batch_normalization_discussion)
注:TensorFlow版本為1.6.0
實戰(zhàn)板塊主要分為兩部分:
- 網(wǎng)絡(luò)構(gòu)建與輔助函數(shù)
- BN測試
首先我們先定義一下神經(jīng)網(wǎng)絡(luò)的類,這個類里面主要包括了以下方法:
- build_network:前向計算
- fully_connected:全連接計算
- train:訓練模型
- test:測試模型
我們首先通過構(gòu)造函數(shù),把權(quán)重、激活函數(shù)以及是否使用BN這些變量傳入,并生成一個training_accuracies來記錄訓練過程中的模型準確率變化。這里的initial_weights是一個list,list中每一個元素是一個矩陣(二維tuple),存儲了每一層的權(quán)重矩陣。build_network實現(xiàn)了網(wǎng)絡(luò)的構(gòu)建,并調(diào)用了fully_connected函數(shù)(下面會提)進行計算。要注意的是,由于MNIST是多分類,在這里我們不需要對最后一層進行激活,保留計算的logits就好。
1.2 fully_connected這里的fully_connected主要用來每一層的線性與非線性計算。通過self.use_batch_norm來控制是否使用BN。
另外,值得注意的是,tf.layers.batch_normalization接口中training參數(shù)非常重要,官方文檔中描述為:
training: Either a Python boolean, or a TensorFlow boolean scalar tensor (e.g. a placeholder). Whether to return the output in training mode (normalized with statistics of the current batch) or in inference mode (normalized with moving statistics). NOTE: make sure to set this parameter correctly, or else your training/inference will not work properly.
當我們訓練時,要設(shè)置為True,保證在訓練過程中使用的是mini-batch的統(tǒng)計量進行normalization;在Inference階段,使用False,也就是使用總體樣本的無偏估計。
1.3 traintrain函數(shù)主要用來進行模型的訓練。除了要定義label,loss以及optimizer以外,我們還需要注意,官方文檔指出在使用BN時的事項:
Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.
因此當self.use_batch_norm為True時,要使用tf.control_dependencies保證模型正常訓練。
1.4 test注意:在訓練過程中batch_size選了60(mnist.train.next_batch(60)),這里是因為BN的原paper中用的60。( We trained the network for 50000 steps, with 60 examples per mini-batch.)
test階段與train類似,只是要設(shè)置self.is_training=False,保證Inference階段BN的正確。
經(jīng)過上面的步驟,我們的框架基本就搭好了,接下來我們再寫一個輔助函數(shù)train_and_test以及plot繪圖函數(shù)就可以開始對BN進行測試啦。train_and_test以及plot函數(shù)見GitHub代碼中,這里不再贅述。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。