博客專欄

EEPW首頁 > 博客 > 知識蒸餾綜述:代碼整理(1)

知識蒸餾綜述:代碼整理(1)

發(fā)布人:計算機視覺工坊 時間:2022-01-16 來源:工程師 發(fā)布文章

作者 | PPRP 

來源 | GiantPandaCV

編輯 | 極市平臺

導(dǎo)讀

本文收集自RepDistiller中的蒸餾方法,盡可能簡單解釋蒸餾用到的策略,并提供了實現(xiàn)源碼。

1. KD: Knowledge Distillation

全稱:Distilling the Knowledge in a Neural Network

鏈接:https://arxiv.org/pdf/1503.02531.pd3f

發(fā)表:NIPS14

最經(jīng)典的,也是明確提出知識蒸餾概念的工作,通過使用帶溫度的softmax函數(shù)來軟化教師網(wǎng)絡(luò)的邏輯層輸出作為學生網(wǎng)絡(luò)的監(jiān)督信息,

使用KL divergence來衡量學生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的差異,具體流程如下圖所示(來自Knowledge Distillation A Survey)

1.jpg

對學生網(wǎng)絡(luò)來說,一部分監(jiān)督信息來自hard label標簽,另一部分來自教師網(wǎng)絡(luò)提供的soft label。代碼實現(xiàn):

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T
    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss

核心就是一個kl_div函數(shù),用于計算學生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的分布差異。

2. FitNet: Hints for thin deep nets

全稱:Fitnets: hints for thin deep nets

鏈接:https://arxiv.org/pdf/1412.6550.pdf

發(fā)表:ICLR 15 Poster

對中間層進行蒸餾的開山之作,通過將學生網(wǎng)絡(luò)的feature map擴展到與教師網(wǎng)絡(luò)的feature map相同尺寸以后,使用均方誤差MSE Loss來衡量兩者差異。

2.jpg

實現(xiàn)如下:

class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()
    def forward(self, f_s, f_t):
        loss = self.crit(f_s, f_t)
        return loss

實現(xiàn)核心就是MSELoss。

3. AT: Attention Transfer

全稱:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

鏈接:https://arxiv.org/pdf/1612.03928.pdf

發(fā)表:ICLR16

為了提升學生模型性能提出使用注意力作為知識載體進行遷移,文中提到了兩種注意力,一種是activation-based attention transfer,另一種是gradient-based attention transfer。實驗發(fā)現(xiàn)第一種方法既簡單效果又好。

3.jpg

實現(xiàn)如下:

class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
    via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p
    def forward(self, g_s, g_t):
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def at_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()
    def at(self, f):
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

首先使用avgpool將尺寸調(diào)整一致,然后使用MSE Loss來衡量兩者差距。

4. SP: Similarity-Preserving

全稱:Similarity-Preserving Knowledge Distillation

鏈接:https://arxiv.org/pdf/1907.09682.pdf

發(fā)表:ICCV19SP

歸屬于基于關(guān)系的知識蒸餾方法。文章思想是提出相似性保留的知識,使得教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)會對相同的樣本產(chǎn)生相似的激活??梢詮南聢D看出處理流程,教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)對應(yīng)feature map通過計算內(nèi)積,得到bsxbs的相似度矩陣,然后使用均方誤差來衡量兩個相似度矩陣。

4.jpg

最終Loss為:

G代表的就是bsxbs的矩陣。實現(xiàn)如下:

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()
    def forward(self, g_s, g_t):
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        G_s = torch.mm(f_s, torch.t(f_s))
        # G_s = G_s / G_s.norm(2)
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        # G_t = G_t / G_t.norm(2)
        G_t = torch.nn.functional.normalize(G_t)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss

5. CC: Correlation Congruence

全稱:Correlation Congruence for Knowledge Distillation

鏈接:https://arxiv.org/pdf/1904.01802.pdf

發(fā)表:ICCV19

CC也歸屬于基于關(guān)系的知識蒸餾方法。不應(yīng)該僅僅引導(dǎo)教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)單個樣本向量之間的差異,還應(yīng)該學習兩個樣本之間的相關(guān)性,而這個相關(guān)性使用的是Correlation Congruence 教師網(wǎng)絡(luò)雨學生網(wǎng)絡(luò)相關(guān)性之間的歐氏距離。

整體Loss如下:

實現(xiàn)如下:

class Correlation(nn.Module):
    """Similarity-preserving loss. My origianl own reimplementation 
    based on the paper before emailing the original authors."""
    def __init__(self):
        super(Correlation, self).__init__()
    def forward(self, f_s, f_t):
        return self.similarity_loss(f_s, f_t)
    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = G_s / G_s.norm(2)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = G_t / G_t.norm(2)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss


*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專區(qū)

關(guān)閉