博客專欄

EEPW首頁 > 博客 > NeurIPS 2022 | 四分鐘內就能訓練目標檢測器,商湯基模型團隊是怎么做到的?

NeurIPS 2022 | 四分鐘內就能訓練目標檢測器,商湯基模型團隊是怎么做到的?

發(fā)布人:機器之心 時間:2022-11-19 來源:工程師 發(fā)布文章

來自商湯的基模型團隊和香港大學等機構的研究人員提出了一種大批量訓練算法 AGVM,該研究已被NeurIPS 2022接收。


本文提出了一種大批量訓練算法 AGVM (Adaptive Gradient Variance Modulator),不僅可以適配于目標檢測任務,同時也可以適配各類分割任務。AGVM 可以把目標檢測的訓練批量大小擴大到 1536,幫助研究人員四分鐘訓練 Faster R-CNN,3.5 小時把 COCO 刷到 62.2 mAP,均打破了目標檢測訓練速度的世界紀錄。


圖片


  • 論文地址:https://arxiv.org/pdf/2210.11078.pdf

  • 代碼地址:https://github.com/Sense-X/AGVM


在當前的機器學習社區(qū)中,有三個普遍的趨勢。首先,神經(jīng)網(wǎng)絡模型會越來越大。在 NLP 領域中最大規(guī)模的模型已經(jīng)達到了上萬億級別。在視覺領域,最大規(guī)模的模型也達到了三百億的量級。其次,訓練的數(shù)據(jù)集也變得越來越大。比如,ImageNet 21k 和谷歌的 JFT 數(shù)據(jù)集都具有相當規(guī)模的數(shù)據(jù)集。另外,由于數(shù)據(jù)集變得越來越大,訓練 SOTA 模型的開銷越來越大。


因此,提升訓練效率就變得愈發(fā)重要。而分布式訓練因為其適應于數(shù)據(jù)并行、模型并行和流水線并行的加速訓練方法的同時,也具備較高的 Deep Learning 通信效率而被廣泛認為是一個有效的解決方案。


隨著大模型時代的到來,目標檢測器的訓練速度越來越成為學術界和工業(yè)界的瓶頸,例如,在 COCO 的標準 setting 上把 mAP 訓到 62 以上大概需要三天的時間,算上調試成本,這在業(yè)界幾乎是不可接受的。那么,我們能不能把這個訓練時間壓到小時級別呢?事實上,在圖片分類和自然語言處理任務上,先前的研究人員借助 32K 的批量大小(batch size),只需 14 分鐘就可以完成 ImageNet 的訓練,76 分鐘完成 Bert 的訓練。但是,在目標檢測領域,還很欠缺這類研究,導致研究人員無法充分利用當前的算力,數(shù)據(jù)集和大模型。


大批量訓練算法 AGVM 便是這個問題的最佳解決方案之一。為了支持如此大批量的訓練,同時保持模型的訓練精度,本研究提出了一套全新的訓練算法,根據(jù)密集預測不同模塊的梯度方差(gradient variance),動態(tài)調整每一個模塊的學習率。作者在大量的密集預測網(wǎng)絡和數(shù)據(jù)集上進行了實驗,并且證實了該方法的合理性。

 

方法介紹


大批量訓練是加速大型分布式系統(tǒng)中深度神經(jīng)網(wǎng)絡訓練的關鍵。尤其是在如今的大模型時代,如果不采用大批量訓練,一個網(wǎng)絡的訓練時間幾乎是難以接受的。但是,大批量訓練很難,因為它會產生泛化差距(generalization gap), 直接訓練會導致其準確率降低。此前的大批量工作往往針對于圖像分類以及一些自然語言處理的任務,但密集預測任務(包括檢測分割等),同樣在視覺中處于舉足輕重的位置,此前的方法并不能在密集預測任務上有很好的表現(xiàn),甚至結果比基準線更差,這導致我們難以快速訓練一個目標檢測器。

 

為了解決這個問題,研究人員進行了大量的實驗。最后發(fā)現(xiàn),相較于傳統(tǒng)的分類網(wǎng)絡,利用密集預測網(wǎng)絡一個很重要的特征:密集預測網(wǎng)絡往往是由多個組件組成的,以 Faster R-CNN 為例:它由四個部分組成,骨干網(wǎng)絡 (Backbone),特征金字塔網(wǎng)絡(FPN),區(qū)域生成網(wǎng)絡(RPN) 和檢測頭網(wǎng)絡(head),我們可以發(fā)現(xiàn)一個很有效的指標:密集預測網(wǎng)絡不同組件的梯度方差,在訓練批量很小時(例如 32),幾乎是相同的,但當訓練批量很大時(例如 512),它們呈現(xiàn)出很大的區(qū)別,如下圖所示:


圖片

那么,能不能直接把這些拉平呢?這直接引出了 AGVM 算法。以隨機梯度下降算法為例,上角標 i 代表第 i 個網(wǎng)絡模塊(例如 FPN 等),上角標 1 代表骨干網(wǎng)絡,圖片代表學習率,錨定骨干網(wǎng)絡,可以直接將不同網(wǎng)絡組件的梯度 g 的方差圖片


圖片


梯度的方差圖片可以由以下式子估計:


圖片

方差的具體求解細節(jié)可以參考原文,本研究同樣引入了滑動平均機制,防止網(wǎng)絡訓練發(fā)散。同時,研究證明了 AGVM 在非凸情況下的收斂性,討論了動量以及衰減的處理方式,具體實現(xiàn)細節(jié)可以參考原文。

 

實驗過程


本研究首先在目標檢測、實例分割、全景分割和語義分割的各種密集預測網(wǎng)絡上進行了測試,通過下表可以看到,當用標準批量大小訓練時,AGVM 相較傳統(tǒng)方法沒有明顯優(yōu)勢,但當在超大批量下訓練時,AGVM 相較傳統(tǒng)方法擁有壓倒性的優(yōu)勢,下圖第二列從左至右分別表示目標檢測,實例分割,全景分割和語義分割的表現(xiàn),AGVM 超越了有史以來的所有方法:


圖片

下表詳細對比了 AGVM 和傳統(tǒng)方法,體現(xiàn)出了本研究方法的優(yōu)勢:


圖片

同時,為了說明 AGVM 的優(yōu)越性,本研究進行了以下三個超大規(guī)模的實驗。研究人員把 Faster R-CNN 的 batch size 放到了 1536,這樣利用 768 張 A100 可以在 4.2 分鐘內完成訓練。其次,借助 UniNet-G,本研究可以在利用 480 張 A100 的情況下,3.5 個小時讓模型在 COCO 上達到 62.2mAP(不包括骨干網(wǎng)絡預訓練的時間),極大的減小了訓練時間:


圖片

甚至,在 RetinaNet 上,本研究把批量大小擴展到 10K。這在目標檢測領域是從未見的批量大小,在如此大的批量下,每一個 epoch 只有十幾個迭代次數(shù),AGVM 在如此大的批量下,仍然能展現(xiàn)出很強的穩(wěn)定性,性能如下圖所示:


圖片

結果分析


本研究探究了一個很重要的問題:以 RetinaNet 為例,如下圖第一列所示,探究為什么會出現(xiàn)梯度方差不匹配這一現(xiàn)象。


本研究認為,這一現(xiàn)象來自于:網(wǎng)絡不同模塊間的有效批量大小 (effective batch size) 是不同的。例如,RetinaNet 的頭網(wǎng)絡的輸入是由特征金字塔的五層網(wǎng)絡輸出的,特征金字塔的 top-down 和 bottom-up pathways,以及像素維度的損失函數(shù)計算會導致頭網(wǎng)絡和骨干網(wǎng)絡的等效批量大小不同,這一原理導致了梯度方差不匹配的現(xiàn)象。


為了驗證這一假設,本研究依次給每一層特征使用單獨的頭網(wǎng)絡,移去特征金字塔網(wǎng)絡,隨機忽略掉 75% 的用于計算損失函數(shù)的像素,最終,本研究發(fā)現(xiàn)骨干網(wǎng)絡和頭網(wǎng)絡的梯度方差曲線重合了,本研究也對 Faster R-CNN 做了類似的實驗,如下圖第二列所示,更多的討論請參見原文。


圖片

圖片



*博客內容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯(lián)系工作人員刪除。



關鍵詞: AI

相關推薦

技術專區(qū)

關閉