博客專欄

EEPW首頁 > 博客 > 知識蒸餾綜述:代碼整理(2)

知識蒸餾綜述:代碼整理(2)

發(fā)布人:計算機視覺工坊 時間:2022-01-16 來源:工程師 發(fā)布文章

6. VID: Variational Information Distillation

全稱:Variational Information Distillation for Knowledge Transfer

鏈接:https://arxiv.org/pdf/1904.05835.pdf

發(fā)表:CVPR19

5.jpg

利用互信息(Mutual Information)來衡量學生網(wǎng)絡和教師網(wǎng)絡差異。互信息可以表示出兩個變量的互相依賴程度,其值越大,表示變量之間的依賴程度越高。互信息計算如下:

互信息是教師模型的熵減去在已知學生模型條件下教師模型的熵。目標是最大化互信息,因為互信息越大說明H(t|s)越小,即學生網(wǎng)絡確定的情況下,教師網(wǎng)絡的熵會變小,證明學生網(wǎng)絡已經(jīng)學習的比較充分。整體loss如下:

由于p(t|s)很難計算,可以使用變分分布q(t|s)去接近真實分布。

其中q(t|s)是使用方差可學習的高斯分布模擬(公式中的log_scale):

實現(xiàn)如下:

class VIDLoss(nn.Module):
    """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
    code from author: https://github.com/ssahn0215/variational-information-distillation"""
    def __init__(self,
                 num_input_channels,
                 num_mid_channel,
                 num_target_channels,
                 init_pred_var=5.0,
                 eps=1e-5):
        super(VIDLoss, self).__init__()
        def conv1x1(in_channels, out_channels, stride=1):
            return nn.Conv2d(
                in_channels, out_channels,
                kernel_size=1, padding=0,
                bias=False, stride=stride)
        self.regressor = nn.Sequential(
            conv1x1(num_input_channels, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_target_channels),
        )
        self.log_scale = torch.nn.Parameter(
            np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
            )
        self.eps = eps
    def forward(self, input, target):
        # pool for dimentsion match
        s_H, t_H = input.shape[2], target.shape[2]
        if s_H > t_H:
            input = F.adaptive_avg_pool2d(input, (t_H, t_H))
        elif s_H < t_H:
            target = F.adaptive_avg_pool2d(target, (s_H, s_H))
        else:
            pass
        pred_mean = self.regressor(input)
        pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        neg_log_prob = 0.5*(
            (pred_mean-target)**2/pred_var+torch.log(pred_var)
            )
        loss = torch.mean(neg_log_prob)
        return loss

7. RKD: Relation Knowledge Distillation

全稱:Relational Knowledge Disitllation

鏈接:http://arxiv.org/pdf/1904.05068

發(fā)表:CVPR19

RKD也是基于關系的知識蒸餾方法,RKD提出了兩種損失函數(shù),二階的距離損失和三階的角度損失。

Distance-wise Loss

Angle-wise Loss

實現(xiàn)如下:

class RKDLoss(nn.Module):
    """Relational Knowledge Disitllation, CVPR2019"""
    def __init__(self, w_d=25, w_a=50):
        super(RKDLoss, self).__init__()
        self.w_d = w_d
        self.w_a = w_a
    def forward(self, f_s, f_t):
        student = f_s.view(f_s.shape[0], -1)
        teacher = f_t.view(f_t.shape[0], -1)
        # RKD distance loss
        with torch.no_grad():
            t_d = self.pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td
        d = self.pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d
        loss_d = F.smooth_l1_loss(d, t_d)
        # RKD Angle loss
        with torch.no_grad():
            td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
            norm_td = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
        loss_a = F.smooth_l1_loss(s_angle, t_angle)
        loss = self.w_d * loss_d + self.w_a * loss_a
        return loss
    @staticmethod
    def pdist(e, squared=False, eps=1e-12):
        e_square = e.pow(2).sum(dim=1)
        prod = e @ e.t()
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        res = res.clone()
        res[range(len(e)), range(len(e))] = 0
        return res

8. PKT:Probabilistic Knowledge Transfer

全稱:Probabilistic Knowledge Transfer for deep representation learning鏈接:https://arxiv.org/abs/1803.10837發(fā)表:CoRR18

提出一種概率知識轉(zhuǎn)移方法,引入了互信息來進行建模。該方法具有可跨模態(tài)知識轉(zhuǎn)移、無需考慮任務類型、可將手工特征融入網(wǎng)絡等有點。

6.jpg

實現(xiàn)如下:

class PKT(nn.Module):
    """Probabilistic Knowledge Transfer for deep representation learning
    Code from author: https://github.com/passalis/probabilistic_kt"""
    def __init__(self):
        super(PKT, self).__init__()
    def forward(self, f_s, f_t):
        return self.cosine_similarity_loss(f_s, f_t)
    @staticmethod
    def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
        # Normalize each vector by its norm
        output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
        output_net = output_net / (output_net_norm + eps)
        output_net[output_net != output_net] = 0
        target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
        target_net = target_net / (target_net_norm + eps)
        target_net[target_net != target_net] = 0
        # Calculate the cosine similarity
        model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
        target_similarity = torch.mm(target_net, target_net.transpose(0, 1))
        # Scale cosine similarity to 0..1
        model_similarity = (model_similarity + 1.0) / 2.0
        target_similarity = (target_similarity + 1.0) / 2.0
        # Transform them into probabilities
        model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
        target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)
        # Calculate the KL-divergence
        loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
        return loss

9. AB: Activation Boundaries

全稱:Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons

鏈接:https://arxiv.org/pdf/1811.03233.pdf

發(fā)表:AAAI18

目標:讓教師網(wǎng)絡層的神經(jīng)元的激活邊界盡量和學生網(wǎng)絡的一樣。所謂的激活邊界指的是分離超平面(針對的是RELU這種激活函數(shù)),其決定了神經(jīng)元的激活與失活。AB提出的激活轉(zhuǎn)移損失,讓教師網(wǎng)絡與學生網(wǎng)絡之間的分離邊界盡可能一致。

7.jpg

實現(xiàn)如下:

class ABLoss(nn.Module):
    """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
    code: https://github.com/bhheo/AB_distillation
    """
    def __init__(self, feat_num, margin=1.0):
        super(ABLoss, self).__init__()
        self.w = [2**(i-feat_num+1) for i in range(feat_num)]
        self.margin = margin
    def forward(self, g_s, g_t):
        bsz = g_s[0].shape[0]
        losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
        losses = [w * l for w, l in zip(self.w, losses)] 
        # loss = sum(losses) / bsz
        # loss = loss / 1000 * 3
        losses = [l / bsz for l in losses]
        losses = [l / 1000 * 3 for l in losses]
        return losses
    def criterion_alternative_l2(self, source, target):
        loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
                (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
        return torch.abs(loss).sum()

10. FT: Factor Transfer

全稱:Paraphrasing Complex Network: Network Compression via Factor Transfer

鏈接:https://arxiv.org/pdf/1802.04977.pdf

發(fā)表:NIPS18

提出的是factor transfer的方法。所謂的factor,其實是對模型最后的數(shù)據(jù)結(jié)果進行一個編解碼的過程,提取出的一個factor矩陣,用教師網(wǎng)絡的factor來指導學生網(wǎng)絡的factor。

8.jpg

FT計算公式為:

實現(xiàn)如下:

class FactorTransfer(nn.Module):
    """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
    def __init__(self, p1=2, p2=1):
        super(FactorTransfer, self).__init__()
        self.p1 = p1
        self.p2 = p2
    def forward(self, f_s, f_t):
        return self.factor_loss(f_s, f_t)
    def factor_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        if self.p2 == 1:
            return (self.factor(f_s) - self.factor(f_t)).abs().mean()
        else:
            return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()
    def factor(self, f):
        return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))


*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯(lián)系工作人員刪除。



關鍵詞: AI

相關推薦

技術專區(qū)

關閉