Knowledge Distillation Survey: Code Collection (2)
6. VID: Variational Information Distillation
Full title: Variational Information Distillation for Knowledge Transfer
Link: https://arxiv.org/pdf/1904.05835.pdf
Published: CVPR 2019
VID uses mutual information to measure the discrepancy between the student network and the teacher network. Mutual information captures how strongly two variables depend on each other: the larger its value, the stronger the dependence. It is computed as follows:
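In the paper's notation, with t a teacher feature and s the corresponding student feature:

$$I(t; s) = H(t) - H(t \mid s)$$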
The mutual information is the entropy of the teacher minus the conditional entropy of the teacher given the student. The objective is to maximize the mutual information: a larger value means a smaller H(t|s), i.e., once the student's features are known there is little uncertainty left about the teacher's, which indicates that the student has absorbed the teacher's knowledge well. The overall loss is:
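Sketched in the paper's spirit, with $\mathcal{L}_S$ the student's original task loss (e.g. cross-entropy) and $\lambda_k$ a weight for the $k$-th distilled layer pair, the student is trained to minimize:

$$\mathcal{L} = \mathcal{L}_S - \sum_{k=1}^{K} \lambda_k \, I\big(t^{(k)}; s^{(k)}\big)$$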
Since p(t|s) is intractable, a variational distribution q(t|s) is used to approximate the true distribution.
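This yields the standard variational lower bound on the mutual information (the inequality follows from the non-negativity of the KL divergence between p(t|s) and q(t|s)):

$$I(t; s) = H(t) + \mathbb{E}_{t,s}\big[\log p(t \mid s)\big] \ \ge\ H(t) + \mathbb{E}_{t,s}\big[\log q(t \mid s)\big]$$

Since H(t) does not depend on the student, maximizing this bound amounts to maximizing the expected log-likelihood of the teacher features under q(t|s).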
Here q(t|s) is modeled as a Gaussian with a learnable variance (the log_scale parameter in the code below):
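Concretely, matching the implementation below, the mean μ(s) is predicted from the student feature by a small regressor and the variance σ_c² is a learnable per-channel parameter, so the per-location negative log-likelihood (up to an additive constant) is

$$-\log q(t \mid s) = \frac{1}{2}\left(\frac{\big(t_{c,h,w} - \mu_{c,h,w}(s)\big)^2}{\sigma_c^2} + \log \sigma_c^2\right) + \text{const}, \qquad \sigma_c^2 = \operatorname{softplus}(\alpha_c) + \varepsilon$$

where $\alpha_c$ is the learnable log_scale entry for channel $c$.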
Implementation:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VIDLoss(nn.Module):
    """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
    code from author: https://github.com/ssahn0215/variational-information-distillation"""
    def __init__(self,
                 num_input_channels,
                 num_mid_channel,
                 num_target_channels,
                 init_pred_var=5.0,
                 eps=1e-5):
        super(VIDLoss, self).__init__()

        def conv1x1(in_channels, out_channels, stride=1):
            return nn.Conv2d(
                in_channels, out_channels,
                kernel_size=1, padding=0,
                bias=False, stride=stride)

        # variational network: predicts the mean of q(t|s) from the student feature
        self.regressor = nn.Sequential(
            conv1x1(num_input_channels, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_target_channels),
        )
        # per-channel parameter for the predicted variance (passed through softplus)
        self.log_scale = torch.nn.Parameter(
            np.log(np.exp(init_pred_var - eps) - 1.0) * torch.ones(num_target_channels)
        )
        self.eps = eps

    def forward(self, input, target):
        # pool for dimension match
        s_H, t_H = input.shape[2], target.shape[2]
        if s_H > t_H:
            input = F.adaptive_avg_pool2d(input, (t_H, t_H))
        elif s_H < t_H:
            target = F.adaptive_avg_pool2d(target, (s_H, s_H))
        else:
            pass
        pred_mean = self.regressor(input)
        # softplus keeps the predicted variance positive
        pred_var = torch.log(1.0 + torch.exp(self.log_scale)) + self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        # Gaussian negative log-likelihood of the teacher feature, up to a constant
        neg_log_prob = 0.5 * (
            (pred_mean - target) ** 2 / pred_var + torch.log(pred_var)
        )
        loss = torch.mean(neg_log_prob)
        return loss
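A minimal usage sketch (the shapes and channel counts here are made up for illustration): one VIDLoss module is created per distilled layer pair, and its output is added to the task loss with some weight.

vid = VIDLoss(num_input_channels=64, num_mid_channel=128, num_target_channels=128)

f_s = torch.randn(8, 64, 32, 32)    # student intermediate feature map
f_t = torch.randn(8, 128, 32, 32)   # teacher intermediate feature map
loss = vid(f_s, f_t)                # add to the task loss with a weighting factor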
7. RKD: Relational Knowledge Distillation
Full title: Relational Knowledge Distillation
Link: http://arxiv.org/pdf/1904.05068
Published: CVPR 2019
RKD is a relation-based knowledge distillation method. It proposes two loss functions: a second-order distance-wise loss and a third-order angle-wise loss.
Distance-wise Loss
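In the paper's notation (a sketch), the distance-wise potential normalizes each pairwise distance by the mean distance μ over the mini-batch, and the loss matches the student's potentials to the teacher's with the Huber (smooth-L1) loss $l_\delta$:

$$\psi_D(t_i, t_j) = \frac{1}{\mu}\lVert t_i - t_j \rVert_2, \qquad \mathcal{L}_{RKD\text{-}D} = \sum_{(i,j)} l_\delta\big(\psi_D(s_i, s_j),\ \psi_D(t_i, t_j)\big)$$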
Angle-wise Loss
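Similarly (a sketch), the angle-wise potential is the cosine of the angle formed by a triplet of examples, computed from unit difference vectors $e^{ij} = (t_i - t_j)/\lVert t_i - t_j \rVert_2$:

$$\psi_A(t_i, t_j, t_k) = \big\langle e^{ij}, e^{kj} \big\rangle, \qquad \mathcal{L}_{RKD\text{-}A} = \sum_{(i,j,k)} l_\delta\big(\psi_A(s_i, s_j, s_k),\ \psi_A(t_i, t_j, t_k)\big)$$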
Implementation:
class RKDLoss(nn.Module):
    """Relational Knowledge Distillation, CVPR 2019"""
    def __init__(self, w_d=25, w_a=50):
        super(RKDLoss, self).__init__()
        self.w_d = w_d  # weight of the distance-wise loss
        self.w_a = w_a  # weight of the angle-wise loss

    def forward(self, f_s, f_t):
        student = f_s.view(f_s.shape[0], -1)
        teacher = f_t.view(f_t.shape[0], -1)

        # RKD distance loss: pairwise distances normalized by their batch mean
        with torch.no_grad():
            t_d = self.pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td

        d = self.pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d

        loss_d = F.smooth_l1_loss(d, t_d)

        # RKD angle loss: cosines of angles formed by triplets of examples
        with torch.no_grad():
            td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
            norm_td = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)

        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)

        loss_a = F.smooth_l1_loss(s_angle, t_angle)

        loss = self.w_d * loss_d + self.w_a * loss_a

        return loss

    @staticmethod
    def pdist(e, squared=False, eps=1e-12):
        e_square = e.pow(2).sum(dim=1)
        prod = e @ e.t()
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

        if not squared:
            res = res.sqrt()

        res = res.clone()
        res[range(len(e)), range(len(e))] = 0  # zero out the diagonal (self-distances)

        return res
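A quick usage sketch (dimensions are illustrative): RKD is typically applied to the penultimate embeddings of the two networks, and the student and teacher dimensions need not match, since only within-batch relations are compared.

rkd = RKDLoss(w_d=25, w_a=50)
f_s = torch.randn(16, 128)   # student embeddings for a mini-batch
f_t = torch.randn(16, 256)   # teacher embeddings for the same mini-batch
loss = rkd(f_s, f_t)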
8. PKT: Probabilistic Knowledge Transfer
Full title: Probabilistic Knowledge Transfer for deep representation learning
Link: https://arxiv.org/abs/1803.10837
Published: CoRR 2018
PKT is a probabilistic knowledge transfer method that builds on mutual information for its modeling. Its advantages include supporting cross-modal knowledge transfer, being agnostic to the task type, and allowing hand-crafted features to be incorporated into the network.
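Roughly, in the author's implementation below, the pairwise cosine similarities within a batch are rescaled to [0, 1], each row of the resulting similarity matrix is normalized into a conditional probability distribution, and the student's distribution is matched to the teacher's with a KL divergence (the code averages over entries instead of summing):

$$K_{ij} = \frac{1}{2}\left(\frac{x_i^\top x_j}{\lVert x_i \rVert \lVert x_j \rVert} + 1\right), \qquad p_{ij} = \frac{K_{ij}}{\sum_k K_{ik}}, \qquad \mathcal{L} = \sum_{i,j} p^{(t)}_{ij} \log \frac{p^{(t)}_{ij}}{p^{(s)}_{ij}}$$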
Implementation:
class PKT(nn.Module):
    """Probabilistic Knowledge Transfer for deep representation learning
    Code from author: https://github.com/passalis/probabilistic_kt"""
    def __init__(self):
        super(PKT, self).__init__()

    def forward(self, f_s, f_t):
        return self.cosine_similarity_loss(f_s, f_t)

    @staticmethod
    def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
        # Normalize each vector by its norm
        output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
        output_net = output_net / (output_net_norm + eps)
        output_net[output_net != output_net] = 0

        target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
        target_net = target_net / (target_net_norm + eps)
        target_net[target_net != target_net] = 0

        # Calculate the cosine similarity
        model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
        target_similarity = torch.mm(target_net, target_net.transpose(0, 1))

        # Scale cosine similarity to 0..1
        model_similarity = (model_similarity + 1.0) / 2.0
        target_similarity = (target_similarity + 1.0) / 2.0

        # Transform them into probabilities
        model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
        target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

        # Calculate the KL-divergence
        loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))

        return loss
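Usage sketch (shapes illustrative): PKT operates on flattened embeddings and, like RKD, does not require the student and teacher dimensionalities to match.

pkt = PKT()
f_s = torch.randn(16, 128)   # student embeddings
f_t = torch.randn(16, 256)   # teacher embeddings
loss = pkt(f_s, f_t)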
9. AB: Activation Boundaries
Full title: Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
Link: https://arxiv.org/pdf/1811.03233.pdf
Published: AAAI 2019
Goal: make the activation boundaries of the student network's neurons match those of the teacher network as closely as possible. An activation boundary is the separating hyperplane (for ReLU-style activations) that determines whether a neuron is activated or deactivated. AB proposes an activation transfer loss that pushes the separating boundaries of the teacher and the student to be as consistent as possible.
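Element-wise, with margin μ, student pre-activation s, and teacher pre-activation t, the alternative L2 transfer loss implemented below is

$$\ell(s, t) = (s + \mu)^2 \,\mathbb{1}\big[t \le 0 \ \wedge\ s > -\mu\big] + (s - \mu)^2 \,\mathbb{1}\big[t > 0 \ \wedge\ s \le \mu\big]$$

i.e., where the teacher neuron is inactive the student's pre-activation is pushed below −μ, where the teacher neuron is active it is pushed above μ, and neurons already on the correct side of the margin incur no loss.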
Implementation:
class ABLoss(nn.Module):
    """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
    code: https://github.com/bhheo/AB_distillation
    """
    def __init__(self, feat_num, margin=1.0):
        super(ABLoss, self).__init__()
        # exponentially increasing weights: deeper positions count more
        self.w = [2 ** (i - feat_num + 1) for i in range(feat_num)]
        self.margin = margin

    def forward(self, g_s, g_t):
        bsz = g_s[0].shape[0]
        losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
        losses = [w * l for w, l in zip(self.w, losses)]
        # loss = sum(losses) / bsz
        # loss = loss / 1000 * 3
        losses = [l / bsz for l in losses]
        losses = [l / 1000 * 3 for l in losses]
        return losses

    def criterion_alternative_l2(self, source, target):
        # penalize student pre-activations that fall on the wrong side of the
        # teacher's activation boundary by less than the margin
        loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
                (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
        return torch.abs(loss).sum()
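A usage sketch (shapes illustrative): ABLoss expects lists of pre-activation feature maps, one per distilled position, and returns a list of weighted loss terms. In the full method the student maps are first passed through small connectors so their shapes match the teacher's; here matching shapes are simply assumed.

ab = ABLoss(feat_num=3)

g_s = [torch.randn(8, 32, 32, 32), torch.randn(8, 64, 16, 16), torch.randn(8, 128, 8, 8)]
g_t = [torch.randn(8, 32, 32, 32), torch.randn(8, 64, 16, 16), torch.randn(8, 128, 8, 8)]

losses = ab(g_s, g_t)   # one loss per position
loss = sum(losses)      # combine before calling backward()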
10. FT: Factor Transfer
Full title: Paraphrasing Complex Network: Network Compression via Factor Transfer
Link: https://arxiv.org/pdf/1802.04977.pdf
Published: NeurIPS 2018
FT proposes the factor transfer method. A "factor" is a matrix extracted by encoding and decoding the network's final feature maps; the teacher's factor is then used to guide the student's factor. The FT loss is computed as:
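In the paper (a sketch), the factors $F_T$ and $F_S$ come from a paraphraser trained on the teacher and a translator attached to the student, and the loss is the $p$-norm of the difference between the L2-normalized factors:

$$\mathcal{L}_{FT} = \left\lVert \frac{F_T}{\lVert F_T \rVert_2} - \frac{F_S}{\lVert F_S \rVert_2} \right\rVert_p$$

The snippet below implements only this loss; its factor() helper builds a simplified factor by taking the channel-wise mean of the feature map raised to the power p1, flattening and L2-normalizing it, and then averages the element-wise differences raised to the power p2.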
Implementation:
class FactorTransfer(nn.Module):
    """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
    def __init__(self, p1=2, p2=1):
        super(FactorTransfer, self).__init__()
        self.p1 = p1  # power applied to the feature map before pooling over channels
        self.p2 = p2  # power used when comparing the two factors

    def forward(self, f_s, f_t):
        return self.factor_loss(f_s, f_t)

    def factor_loss(self, f_s, f_t):
        # pool so the spatial sizes of student and teacher match
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        if self.p2 == 1:
            return (self.factor(f_s) - self.factor(f_t)).abs().mean()
        else:
            return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()

    def factor(self, f):
        # channel-wise mean of f**p1, flattened and L2-normalized
        return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))
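Usage sketch (shapes illustrative; in the full method the two inputs would be the paraphraser and translator outputs rather than raw feature maps, so treat this purely as an interface example):

ft = FactorTransfer(p1=2, p2=1)
f_s = torch.randn(8, 64, 8, 8)     # student-side feature map
f_t = torch.randn(8, 128, 8, 8)    # teacher-side feature map
loss = ft(f_s, f_t)                # channel counts may differ; spatial sizes are pooled to match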