A Survey of Knowledge Distillation: Code Walkthrough (1)
Author | PPRP
Source | GiantPandaCV
Editor | 極市平臺
Overview
This article collects the distillation methods implemented in RepDistiller, explains the strategy behind each one as simply as possible, and provides the implementation source code.
1. KD: Knowledge Distillation
Full title: Distilling the Knowledge in a Neural Network
Link: https://arxiv.org/pdf/1503.02531.pdf
Published: NIPS 2014 (workshop)
The classic work that explicitly introduced the concept of knowledge distillation: the teacher network's logits are softened with a temperature-scaled softmax and used as the supervision signal for the student network,
with KL divergence measuring the discrepancy between student and teacher. The overall pipeline is shown in the figure below (from Knowledge Distillation: A Survey).
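Concretely, with temperature T the logits z are softened into (from the paper)

$$q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)},$$

where a higher T yields a softer distribution that exposes more of the teacher's inter-class similarity structure.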
For the student network, part of the supervision comes from the hard ground-truth labels, and the other part comes from the soft labels provided by the teacher network. Implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T  # softmax temperature

    def forward(self, y_s, y_t):
        # Soften both sets of logits with the same temperature.
        p_s = F.log_softmax(y_s / self.T, dim=1)  # student log-probabilities
        p_t = F.softmax(y_t / self.T, dim=1)      # teacher probabilities
        # KL divergence summed over classes, averaged over the batch;
        # the T**2 factor keeps gradient magnitudes comparable across temperatures.
        # (reduction='sum' replaces the deprecated size_average=False.)
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss
The core is a single kl_div call, which measures the divergence between the student's and the teacher's output distributions.
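A minimal usage sketch (the batch size, class count, and temperature below are illustrative choices, not values from the paper; in practice this term is combined with a cross-entropy loss on the hard labels):

criterion_kd = DistillKL(T=4)
y_s = torch.randn(32, 100)  # student logits: batch of 32, 100 classes
y_t = torch.randn(32, 100)  # teacher logits (detached from the teacher's graph in practice)
loss_kd = criterion_kd(y_s, y_t)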
2. FitNet: Hints for thin deep nets
Full title: FitNets: Hints for Thin Deep Nets
Link: https://arxiv.org/pdf/1412.6550.pdf
Published: ICLR 2015 (poster)
The seminal work on distilling intermediate layers: after the student network's feature map is mapped to the same shape as the teacher network's, the mean squared error (MSE) loss measures the difference between the two.
Implementation:
import torch.nn as nn

class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()  # plain MSE between the two feature maps

    def forward(self, f_s, f_t):
        # f_s and f_t must already have identical shapes;
        # a learned regressor aligns them beforehand (see the sketch below).
        loss = self.crit(f_s, f_t)
        return loss
The core of the implementation is simply MSELoss. Note that HintLoss assumes the student and teacher feature maps already match in shape; in FitNets this alignment is done by a learned regressor on top of the student's hint layer, as sketched below.
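A minimal sketch of such a regressor, assuming a 1×1 convolution suffices to match channel counts and that the spatial sizes already agree (the name ConvReg and this exact design are illustrative; RepDistiller ships its own, more general regressor):

import torch.nn as nn

class ConvReg(nn.Module):
    """Illustrative regressor mapping student features to the teacher's
    channel dimension before HintLoss is applied."""
    def __init__(self, s_channels, t_channels):
        super(ConvReg, self).__init__()
        self.conv = nn.Conv2d(s_channels, t_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(t_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, f_s):
        return self.relu(self.bn(self.conv(f_s)))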
3. AT: Attention Transfer
Full title: Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
Link: https://arxiv.org/pdf/1612.03928.pdf
Published: ICLR 2017
To improve the student model, the paper proposes using attention maps as the carrier of transferred knowledge. It introduces two kinds of attention: activation-based attention transfer and gradient-based attention transfer. Experiments show that the first is both simpler and more effective.
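For the activation-based variant, per the paper, a feature map $A \in \mathbb{R}^{C \times H \times W}$ is collapsed into a spatial attention map by aggregating powered absolute activations across channels, e.g.

$$F_{\text{sum}}^{p}(A) = \sum_{c=1}^{C} |A_c|^{p},$$

and the loss compares the L2-normalized, vectorized attention maps of student and teacher. (The snippet below averages over channels instead of summing, which only changes a constant factor.)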
Implementation:
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of
    Convolutional Neural Networks via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p  # power applied to activations

    def forward(self, g_s, g_t):
        # g_s / g_t are lists of feature maps from several layers.
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def at_loss(self, f_s, f_t):
        # Pool the larger feature map down so the spatial sizes match.
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        # Mean squared difference between the normalized attention maps.
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()

    def at(self, f):
        # Spatial attention map: average activation^p over channels,
        # flatten, then L2-normalize per sample.
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
Adaptive average pooling first brings the two feature maps to the same spatial size; the squared difference between the normalized attention maps then serves as the loss.
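A quick usage sketch (shapes are illustrative; channel and spatial sizes may differ between student and teacher):

import torch

criterion_at = Attention(p=2)
g_s = [torch.randn(8, 64, 32, 32)]   # student feature maps, one layer here
g_t = [torch.randn(8, 256, 16, 16)]  # corresponding teacher feature maps
losses = criterion_at(g_s, g_t)      # one loss tensor per layer pair
loss_at = sum(losses)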
4. SP: Similarity-Preserving
Full title: Similarity-Preserving Knowledge Distillation
Link: https://arxiv.org/pdf/1907.09682.pdf
Published: ICCV 2019
SP is a relation-based knowledge distillation method. The idea is to transfer similarity-preserving knowledge, so that the teacher and student networks produce similar activation patterns on the same samples. The pipeline can be seen in the figure below: the teacher's and the student's feature maps are each turned into a bs×bs (batch-size × batch-size) similarity matrix via inner products, and the two similarity matrices are compared with mean squared error.
The final loss for one layer pair, per the paper, is

$$\mathcal{L}_{SP} = \frac{1}{b^{2}}\,\big\|\tilde{G}^{(t)} - \tilde{G}^{(s)}\big\|_{F}^{2},$$

where $G = f f^{\top}$ is the bs×bs similarity matrix of the flattened features, $\tilde{G}$ is its row-normalized version, and $b$ is the batch size. Implementation:
import torch
import torch.nn as nn

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()

    def forward(self, g_s, g_t):
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)  # flatten each sample's features
        f_t = f_t.view(bsz, -1)

        # bsz x bsz similarity (Gram) matrices, row-normalized.
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = torch.nn.functional.normalize(G_t)

        # Mean squared difference between the two similarity matrices.
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
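Because SP compares relations between samples rather than the features themselves, the student and teacher feature maps may have entirely different channel and spatial dimensions; only the batch dimension must match. A usage sketch with illustrative shapes:

import torch

criterion_sp = Similarity()
g_s = [torch.randn(8, 64, 8, 8)]    # one student feature map
g_t = [torch.randn(8, 512, 4, 4)]   # corresponding teacher feature map
loss_sp = sum(criterion_sp(g_s, g_t))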
5. CC: Correlation Congruence
Full title: Correlation Congruence for Knowledge Distillation
Link: https://arxiv.org/pdf/1904.01802.pdf
Published: ICCV 2019
CC is also a relation-based knowledge distillation method. The argument is that distillation should not only align the teacher's and student's representations of individual samples, but also the correlations between pairs of samples; correlation congruence measures the Euclidean distance between the teacher's and the student's correlation matrices.
The overall loss combines the task cross-entropy, the conventional KD term, and the correlation-congruence term; in the paper it takes the form

$$\mathcal{L}_{CCKD} = \alpha\,\mathcal{L}_{CE} + (1-\alpha)\,\mathcal{L}_{KD} + \beta\,\mathcal{L}_{CC}$$
Implementation:
import torch
import torch.nn as nn

class Correlation(nn.Module):
    """Similarity-preserving style loss. My original reimplementation
    based on the paper, before emailing the original authors."""
    def __init__(self):
        super(Correlation, self).__init__()

    def forward(self, f_s, f_t):
        return self.similarity_loss(f_s, f_t)

    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)  # flatten each sample's features
        f_t = f_t.view(bsz, -1)

        # bsz x bsz correlation (Gram) matrices, scaled by their
        # global 2-norm (the Frobenius norm of the whole matrix).
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = G_s / G_s.norm(2)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = G_t / G_t.norm(2)

        # Mean squared difference between the two correlation matrices.
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
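The only substantive difference from the SP snippet above is the normalization: here each Gram matrix is scaled by its global 2-norm (a single scalar over the whole matrix) rather than row-normalized. A usage sketch with illustrative shapes:

import torch

criterion_cc = Correlation()
f_s = torch.randn(8, 128)  # student embeddings
f_t = torch.randn(8, 256)  # teacher embeddings
loss_cc = criterion_cc(f_s, f_t)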