ICCV 2021|“白嫖”性能的MixMo,一種新的數(shù)據(jù)增強(qiáng)or模型融合方法
本文作者提出了一種新的多輸入多輸出深度子網(wǎng)學(xué)習(xí)廣義框架MixMo,MixMo可以作為一種集成方法或一種新的混合樣本數(shù)據(jù)增強(qiáng)方法進(jìn)行分析,同時(shí)仍然與兩種研究方向的工作保持互補(bǔ)。
寫在前面
最近的工作提出的不用額外計(jì)算的集成方法,大多是在一個(gè)網(wǎng)絡(luò)中同時(shí)設(shè)置不同的subnet。訓(xùn)練時(shí)。每個(gè)subnet只學(xué)習(xí)分類多個(gè)輸入數(shù)據(jù)中的其中一個(gè)。然而,如何更好地混合這些多個(gè)輸入的問(wèn)題迄今尚未被研究。
在本文,作者提出了一種新的多輸入多輸出深度子網(wǎng)學(xué)習(xí)廣義框架MixMo。作者的Motivation是用一個(gè)更合適的混合機(jī)制來(lái)代替先前方法中求和導(dǎo)致的次優(yōu)操作。受到混合樣本數(shù)據(jù)增強(qiáng)的啟發(fā),作者發(fā)現(xiàn)特征的混合可以使subnet更強(qiáng),使得數(shù)據(jù)更加多樣,進(jìn)而提高模型performance。
基于MixMo,作者提升了CIFAR-100和Tiny ImageNet數(shù)據(jù)集上的SOTA性能。
1. 論文和代碼地址
論文地址:https://arxiv.org/abs/2103.06132
代碼地址:https://github.com/alexrame/mixmo-pytorch
2. Motivation
卷積神經(jīng)網(wǎng)絡(luò)(cnn)在計(jì)算機(jī)視覺任務(wù)中表現(xiàn)出了出色的性能,尤其是分類任務(wù)。為了在真實(shí)場(chǎng)景中增加魯棒性或贏得Kaggle競(jìng)賽,cnn通常會(huì)采用兩種實(shí)用策略:數(shù)據(jù)增強(qiáng) 和模型集成 。
數(shù)據(jù)增強(qiáng)可以減少過(guò)擬合并提升模型的泛化性。傳統(tǒng)的圖像增強(qiáng)是保留標(biāo)簽的:例如翻轉(zhuǎn)、裁剪等。然而,最近的混合樣本數(shù)據(jù)增強(qiáng)(MSDA)改變了這種方式:多個(gè)輸入和它們的標(biāo)簽按比例混合來(lái)創(chuàng)建人工樣本,代表工作有MixUp,CutMix等等。
模型集成證明了聚合來(lái)自多個(gè)神經(jīng)網(wǎng)絡(luò)的不同預(yù)測(cè)能夠顯著提高了泛化能力,尤其是不確定性估計(jì)。從經(jīng)驗(yàn)上講,幾個(gè)小網(wǎng)絡(luò)的集成通常比一個(gè)大網(wǎng)絡(luò)性能更好。然而,在訓(xùn)練和推理方面,集成在時(shí)間和顯存消耗方面都是昂貴的:這往往限制了模型集成的適用性。
在本文,作者提出了多輸入多輸出框架MixMo。為了解決傳統(tǒng)集成中出現(xiàn)的這些開銷,作者將M個(gè)獨(dú)立子網(wǎng)放入一個(gè)單一的base網(wǎng)絡(luò)中。這也是合理的,因?yàn)樵谀P图蓵r(shí),“最終采納的網(wǎng)絡(luò)”其實(shí)就和整體的網(wǎng)絡(luò)表現(xiàn)差不多。
所以,現(xiàn)在最大的問(wèn)題是如何在沒(méi)有結(jié)構(gòu)差異的情況下加強(qiáng)subnet之間的多樣性。
如上圖,作者在訓(xùn)練過(guò)程中同時(shí)考慮了M個(gè)輸入,M個(gè)輸入被M個(gè)參數(shù)不共享的Encoder編碼到共享空間中,然后將特征送到核心網(wǎng)絡(luò),核心網(wǎng)絡(luò)最終分成M個(gè)分支;這個(gè)M個(gè)分支用來(lái)預(yù)測(cè)不同輸入信息的label。在inference的時(shí)候,同一圖像重復(fù)M次:通過(guò)平均M個(gè)預(yù)測(cè)獲得“免費(fèi)”的集成效果。
與現(xiàn)有的MSDA相比,MixMo最大的不同就是multi-input mixing block。如果合并是一個(gè)基本的求和,MixMo將變成到MIMO[1]。作者對(duì)比了大量的MSDA的工作,設(shè)計(jì)了更合適的混合塊,因此作者采用binary masking的方法來(lái)確保子網(wǎng)絡(luò)的多樣性。(如上圖所示,作者對(duì)不同樣本采用了一個(gè)binary masking方法,這一點(diǎn)就類似CutMix,而不是像MIMO那樣直接相加 )。
這種不對(duì)稱的混合也會(huì)造成網(wǎng)絡(luò)特征中的信息不平衡的新問(wèn)題,因此作者通過(guò)一個(gè)新的加權(quán)函數(shù)來(lái)解決多個(gè)分類訓(xùn)練任務(wù)之間的不平衡問(wèn)題。
3. 方法
MixMo的整體結(jié)構(gòu)如上圖所示
3.1. General overview
核心網(wǎng)絡(luò)C需要同時(shí)處理兩種輸入的特征表示。然后多層的網(wǎng)絡(luò)D,通過(guò)這個(gè)mix的特征,再一次把各自樣本的類別識(shí)別出來(lái)。(個(gè)人理解這個(gè)網(wǎng)絡(luò)是一個(gè)“分-總-分”的結(jié)構(gòu),首先,這個(gè)網(wǎng)絡(luò)對(duì)不同輸入的樣本進(jìn)行分別編碼,這是第一個(gè)“分”的過(guò)程;然后這些被編碼的特征通過(guò)Mixing Block融合,這是“總”的過(guò)程;最后不同的層再根據(jù)這個(gè)混合的特征,識(shí)別出各自樣本的類別,這是最后一個(gè)“分”的過(guò)程 )
訓(xùn)練過(guò)程中的損失函數(shù)為各自樣本的交叉熵?fù)p失函數(shù)之和(分別乘上各自的權(quán)重,權(quán)重的計(jì)算見下文):
在inference的時(shí)候,同一個(gè)輸入x被輸入到不同的分支中,核心網(wǎng)絡(luò)C的輸入為的和,這最大的保留了來(lái)自兩種編碼信息。然后,最終的預(yù)測(cè)結(jié)果為將不同分支的預(yù)測(cè)平均值。這使得模型可以在一次前向傳播的過(guò)程中享受模型融合的結(jié)果。
3.2. Mixing block
Mixing block 是MixMo的核心,它將兩個(gè)輸入組合成一個(gè)共享表示。受MSDA混合方法的啟發(fā),MixMo通用框架包含了更廣泛的變化。
作者提出的第一個(gè)變體是 Linear-MixMo,借鑒了MixUp的思想,直接將兩張圖片通過(guò)一個(gè)透明度疊在一起:
接著,作者受到MixCut的啟發(fā),提出了 Cut-MixMo :
與Linear-MixMo不同,這里并不是將整張圖片相加,而是像MixCut一樣,每次都是加了一個(gè)patch。
Cut-MixMo比其他策略表現(xiàn)更好。具體來(lái)說(shuō),Cut-MixMo中的binary mixing取代了MIMO和 Linear-MixMo中的線性插值,使子網(wǎng)絡(luò)更加精確和多樣化。
為什么Cut-MixMo會(huì)比 Linear-MixMo要更好?
1)基于CutMix優(yōu)于Mixup的相同原因,M中的binary mixing訓(xùn)練了更強(qiáng)的單個(gè)子網(wǎng)。此外通過(guò)binary mixing,模擬了常見的物體遮擋問(wèn)題。
2)線性插值從根本上不適合誘導(dǎo)多樣性,因?yàn)閮蓚€(gè)輸入都保留了完整的信息。CutMix通過(guò)交替選擇的圖像patch,顯式地增加了數(shù)據(jù)集的多樣性。
3.3. Loss weighting
Mixing機(jī)制中的不對(duì)稱可能導(dǎo)致一種輸入蓋過(guò)另一種輸入。當(dāng)時(shí),權(quán)重更大的輸入可能更容易預(yù)測(cè)。因此,作者定義了一個(gè)權(quán)重函數(shù)來(lái)平衡多個(gè)損失函數(shù)的重要性。這種加權(quán)調(diào)整了有效學(xué)習(xí)率、梯度在網(wǎng)絡(luò)中的流動(dòng)方式以及混合信息在特征中表示的方式。
加權(quán)函數(shù)具體表示如下:
其中是一個(gè)超參數(shù),的曲線如下圖所示:
3.4. From manifold mixing to MixMo
相比于其他MSDA方法,MixMo使用兩個(gè)獨(dú)立的編碼器(每個(gè)編碼器輸入一個(gè)數(shù)據(jù)),并且它輸出是兩個(gè)預(yù)測(cè)而不是一個(gè)。
而其他MSDA方法使用一個(gè)單一的分類器,該分類器針對(duì)一個(gè)唯一的軟標(biāo)簽,通過(guò)線性插值反映不同的類。相反,MixMo選擇充分利用混合樣本的復(fù)合特性,訓(xùn)練分離的dense層,d0和d1,在測(cè)試時(shí)能夠在沒(méi)有額外計(jì)算的情況下,達(dá)到模型集成的效果。
4.實(shí)驗(yàn)
4.1. Main results on CIFAR-100 and CIFAR-10
上表展示了MixMo在CIFAR10和CIFAR100上的結(jié)果,可以看出相比于原始的網(wǎng)絡(luò),MixMo對(duì)于性能的提升非常明顯。
從上圖可以看出,隨著寬度w的增加,MixMo比DE(綠色曲線)的性能提升更加明顯。
4.2. Training time
可以看出,在相同的訓(xùn)練時(shí)間內(nèi),Cut-MixMo的表現(xiàn)優(yōu)于DE。
4.3. The mixing block
上表比較幾個(gè)mix block的性能,可以看出無(wú)論形狀如何,binary mixing的性能都優(yōu)于線性混合。
4.4. Weighting function
上圖比較了加權(quán)函數(shù)不同r下的性能,r在[3,6]范圍內(nèi)達(dá)到了很好的trade-off。
4.5. Multiple encoders and classifiers
上表的實(shí)驗(yàn)結(jié)果表明,2個(gè)編碼器和2個(gè)分類器對(duì)于實(shí)驗(yàn)結(jié)果是比較好的。
4.6. Pushing MixMo further: Tiny ImageNet
在更大的規(guī)模和更多樣的64 × 64圖像上,Cut-MixMo在Tiny ImageNet上達(dá)到了70.24%的新水平,如上表所示。
5. 總結(jié)
在本文中,作者提出了MixMo,一個(gè)多輸入多輸出策略的框架。MixMo可以作為一種集成方法或一種新的混合樣本數(shù)據(jù)增強(qiáng)方法進(jìn)行分析,同時(shí)仍然與兩種研究方向的工作保持互補(bǔ)。此外,作者引入了一個(gè)新的權(quán)重函數(shù),以平衡訓(xùn)練時(shí)的損失。最終,作者通過(guò)實(shí)驗(yàn)證明了MixMo的有效性。
參考文獻(xiàn)
[1]. Marton Havasi, Rodolphe Jenatton, Stanislav Fort,Jeremiah Liu, Jasper Roland Snoek, Balaji Lakshminarayanan, Andrew Mingbo Dai, and Dustin Tran. Training independent subnetworks for robust prediction. In ICLR,2021.
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。