"); //-->
11. FSP: Flow of Solution Procedure
Full title: A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
Link: https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
Published at: CVPR 2017
FSP argues that teaching the student the relationships between the features output by different layers works better than teaching it the teacher's final outputs.
It defines the FSP matrix to describe the relationship between feature layers inside a network: a Gram-style matrix that reflects how the teacher's flow of the solution procedure is transferred to the student.
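A reconstruction of the paper's definition (the original formula image was lost in extraction): for two feature maps $F^1 \in \mathbb{R}^{h \times w \times m}$ and $F^2 \in \mathbb{R}^{h \times w \times n}$ produced by two layers of the same network, the FSP matrix $G \in \mathbb{R}^{m \times n}$ is

$$G_{i,j}(x; W) = \sum_{s=1}^{h} \sum_{t=1}^{w} \frac{F^1_{s,t,i}(x; W) \times F^2_{s,t,j}(x; W)}{h \times w}$$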
An L2 loss is used to make the student's FSP matrices match the teacher's. The implementation is as follows:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
    """A Gift from Knowledge Distillation:
    Fast Optimization, Network Minimization and Transfer Learning"""
    def __init__(self, s_shapes, t_shapes):
        super(FSP, self).__init__()
        assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
        s_c = [s[1] for s in s_shapes]
        t_c = [t[1] for t in t_shapes]
        # FSP matrices can only be compared when channel counts match
        if np.any(np.asarray(s_c) != np.asarray(t_c)):
            raise ValueError('num of channels not equal (error in FSP)')

    def forward(self, g_s, g_t):
        s_fsp = self.compute_fsp(g_s)
        t_fsp = self.compute_fsp(g_t)
        loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
        return loss_group

    @staticmethod
    def compute_loss(s, t):
        # L2 loss between student and teacher FSP matrices
        return (s - t).pow(2).mean()

    @staticmethod
    def compute_fsp(g):
        # Compute the FSP (Gram) matrix for every pair of adjacent feature maps
        fsp_list = []
        for i in range(len(g) - 1):
            bot, top = g[i], g[i + 1]
            b_H, t_H = bot.shape[2], top.shape[2]
            # Pool the larger map so both share the same spatial resolution
            if b_H > t_H:
                bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
            elif b_H < t_H:
                top = F.adaptive_avg_pool2d(top, (b_H, b_H))
            bot = bot.unsqueeze(1)
            top = top.unsqueeze(2)
            bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
            top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
            # Inner product over spatial positions, averaged: [B, C_top, C_bot]
            fsp = (bot * top).mean(-1)
            fsp_list.append(fsp)
        return fsp_list
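A minimal usage sketch (the shapes here are hypothetical; in practice the s_shapes/t_shapes lists describe each network's intermediate feature maps):

import torch

# hypothetical feature pyramids: matching channel counts, with the student
# at half the teacher's spatial resolution
s_shapes = [(2, 16, 32, 32), (2, 32, 16, 16)]
t_shapes = [(2, 16, 64, 64), (2, 32, 32, 32)]
g_s = [torch.randn(s) for s in s_shapes]
g_t = [torch.randn(t) for t in t_shapes]

criterion = FSP(s_shapes, t_shapes)
loss = sum(criterion(g_s, g_t))  # one L2 term per adjacent layer pair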
12. NST: Neuron Selectivity Transfer
Full title: Like what you like: knowledge distill via neuron selectivity transfer
Link: https://arxiv.org/pdf/1707.01219.pdf
Published at: CoRR 2017
NST uses a new loss function that minimizes the Maximum Mean Discrepancy (MMD) between the teacher and student networks; concretely, the paper aligns the distributions of the two networks' neuron selectivity patterns.
Applying the kernel trick (the poly kernel below) and expanding yields the objective below.
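A reconstruction of the expanded squared-MMD objective (the original formula image was lost), consistent with the nst_loss code that follows; $f_T^i$ and $f_S^j$ are L2-normalized per-channel activation maps, and $C_T$, $C_S$ are the channel counts:

$$\mathcal{L}_{\mathrm{MMD}^2}(F_T, F_S) = \frac{1}{C_T^2} \sum_{i=1}^{C_T} \sum_{i'=1}^{C_T} k(f_T^i, f_T^{i'}) + \frac{1}{C_S^2} \sum_{j=1}^{C_S} \sum_{j'=1}^{C_S} k(f_S^j, f_S^{j'}) - \frac{2}{C_T C_S} \sum_{i=1}^{C_T} \sum_{j=1}^{C_S} k(f_T^i, f_S^j)$$

with the polynomial kernel $k(x, y) = (x^\top y)^2$.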
The paper actually provides three kernels: a linear kernel, a poly kernel, and a Gaussian kernel. Only the poly kernel is implemented here, because the poly variant is complementary to vanilla KD, so combining the two gives very good overall results. The implementation is as follows:
import torch.nn as nn
import torch.nn.functional as F


class NSTLoss(nn.Module):
    """like what you like: knowledge distill via neuron selectivity transfer"""
    def __init__(self):
        super(NSTLoss, self).__init__()

    def forward(self, g_s, g_t):
        return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def nst_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        # Pool the larger feature map so both share the same spatial size
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))

        # Flatten spatial dims and L2-normalize each channel's activation map
        f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
        f_s = F.normalize(f_s, dim=2)
        f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
        f_t = F.normalize(f_t, dim=2)

        # The teacher-teacher term is constant w.r.t. the student, so it can
        # be dropped (full_loss = False) to avoid unnecessary computation;
        # it is kept (detached) here so the value equals the true squared MMD.
        full_loss = True
        if full_loss:
            return (self.poly_kernel(f_t, f_t).mean().detach()
                    + self.poly_kernel(f_s, f_s).mean()
                    - 2 * self.poly_kernel(f_s, f_t).mean())
        else:
            return (self.poly_kernel(f_s, f_s).mean()
                    - 2 * self.poly_kernel(f_s, f_t).mean())

    def poly_kernel(self, a, b):
        # Pairwise polynomial kernel k(x, y) = (x . y)^2 over all channel pairs
        a = a.unsqueeze(1)
        b = b.unsqueeze(2)
        res = (a * b).sum(-1).pow(2)
        return res
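A minimal usage sketch with dummy feature maps (the shapes are hypothetical; note that channel counts may differ between teacher and student):

import torch

g_s = [torch.randn(2, 32, 16, 16)]  # student features, one stage
g_t = [torch.randn(2, 64, 32, 32)]  # teacher features

criterion = NSTLoss()
loss = sum(criterion(g_s, g_t))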
13. CRD: Contrastive Representation Distillation
Full title: Contrastive Representation Distillation
Link: https://arxiv.org/abs/1910.10699v2
Published at: ICLR 2020
CRD introduces contrastive learning into knowledge distillation. The objective is revised to: learn a representation under which the teacher's and student's embeddings of positive pairs are as close as possible, while those of negative pairs are as far apart as possible.
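A sketch of the resulting NCE-style objective, read off Eq. (18) of the paper and the ContrastLoss code below (assumptions: $P_{\mathrm{pos}}$ and $P_{\mathrm{neg}}$ are the critic outputs for positive and negative pairs, $M$ = nce_k is the number of negatives per positive, and $P_n = 1/N$ is a uniform noise prior over the $N$ training samples):

$$\mathcal{L}_{\mathrm{NCE}} = -\,\mathbb{E}_{\mathrm{pos}}\left[\log \frac{P_{\mathrm{pos}}}{P_{\mathrm{pos}} + M P_n}\right] - M\,\mathbb{E}_{\mathrm{neg}}\left[\log \frac{M P_n}{P_{\mathrm{neg}} + M P_n}\right]$$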
The overall distillation loss is the sum of two symmetric contrastive terms, one using the teacher as the anchor and one using the student as the anchor. The implementation below is from https://github.com/HobbitLong/RepDistiller:
import torch
import torch.nn as nn

eps = 1e-7  # small constant for numerical stability (module-level in RepDistiller)


class ContrastLoss(nn.Module):
    """contrastive loss, corresponding to Eq (18)"""
    def __init__(self, n_data):
        super(ContrastLoss, self).__init__()
        self.n_data = n_data

    def forward(self, x):
        bsz = x.shape[0]
        m = x.size(1) - 1

        # noise distribution: uniform over the n_data training samples
        Pn = 1 / float(self.n_data)

        # loss for the positive pair
        P_pos = x.select(1, 0)
        log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()

        # loss for the m negative pairs
        P_neg = x.narrow(1, 1, m)
        log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()

        loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
        return loss


class CRDLoss(nn.Module):
    """CRD Loss function includes two symmetric parts:
    (a) using teacher as anchor, choose positive and negatives over the student side
    (b) using student as anchor, choose positive and negatives over the teacher side

    Args:
        opt.s_dim: the dimension of student's feature
        opt.t_dim: the dimension of teacher's feature
        opt.feat_dim: the dimension of the projection space
        opt.nce_k: number of negatives paired with each positive
        opt.nce_t: the temperature
        opt.nce_m: the momentum for updating the memory buffer
        opt.n_data: the number of samples in the training set; the memory buffer
            is therefore of size opt.n_data x opt.feat_dim
    """
    def __init__(self, opt):
        super(CRDLoss, self).__init__()
        # Embed (a linear projection plus L2 normalization) and ContrastMemory
        # (the memory buffer) are defined elsewhere in RepDistiller
        self.embed_s = Embed(opt.s_dim, opt.feat_dim)
        self.embed_t = Embed(opt.t_dim, opt.feat_dim)
        self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
        self.criterion_t = ContrastLoss(opt.n_data)
        self.criterion_s = ContrastLoss(opt.n_data)

    def forward(self, f_s, f_t, idx, contrast_idx=None):
        """
        Args:
            f_s: the feature of student network, size [batch_size, s_dim]
            f_t: the feature of teacher network, size [batch_size, t_dim]
            idx: the indices of these positive samples in the dataset, size [batch_size]
            contrast_idx: the indices of negative samples, size [batch_size, nce_k]

        Returns:
            The contrastive loss
        """
        f_s = self.embed_s(f_s)
        f_t = self.embed_t(f_t)
        out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
        s_loss = self.criterion_s(out_s)
        t_loss = self.criterion_t(out_t)
        loss = s_loss + t_loss
        return loss
14. Overhaul
Full title: A Comprehensive Overhaul of Feature Distillation
Link: http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
Published at: ICCV 2019
For the teacher transform, a margin ReLU activation is proposed.
For the student transform, a 1x1 convolution is used.
For the distillation feature position, pre-ReLU features are chosen.
For the distance function, a partial L2 loss is proposed (margin ReLU and partial L2 are sketched right after this list).
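A reconstruction of the two key formulas (the original images were lost in extraction). The margin ReLU clips the teacher feature from below with a per-channel negative margin $m_{\mathcal{C}}$, taken as the expectation of that channel's negative responses, and the partial L2 distance ignores positions where the student is already below a non-positive teacher target:

$$\sigma_{m}(x) = \max(x, m), \qquad m_{\mathcal{C}} = \mathbb{E}\left[x \mid x < 0\right]$$

$$d_p(T, S) = \sum_{i} \begin{cases} 0 & \text{if } S_i \le T_i \le 0 \\ (T_i - S_i)^2 & \text{otherwise} \end{cases}$$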
A partial implementation is as follows:
import torch
import torch.nn as nn


class OFD(nn.Module):
    '''A Comprehensive Overhaul of Feature Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
    '''
    def __init__(self, in_channels, out_channels):
        super(OFD, self).__init__()
        # Student transform: 1x1 conv (+ BN) to match the teacher's channels
        self.connector = nn.Sequential(*[
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        ])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        # Teacher transform: margin ReLU, i.e. max(x, margin) with margin < 0
        margin = self.get_margin(fm_t)
        fm_t = torch.max(fm_t, margin)
        fm_s = self.connector(fm_s)

        # Partial L2: mask out positions where the student is already below
        # a non-positive teacher target
        mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
        loss = torch.mean((fm_s - fm_t)**2 * mask)
        return loss

    def get_margin(self, fm, eps=1e-6):
        # Per-channel expectation of the negative responses
        mask = (fm < 0.0).float()
        masked_fm = fm * mask
        margin = masked_fm.sum(dim=(0, 2, 3), keepdim=True) / (mask.sum(dim=(0, 2, 3), keepdim=True) + eps)
        return margin
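A minimal usage sketch (hypothetical shapes; pre-ReLU feature maps with matching spatial sizes are assumed):

import torch

fm_s = torch.randn(2, 32, 16, 16)  # student pre-ReLU feature map
fm_t = torch.randn(2, 64, 16, 16)  # teacher pre-ReLU feature map

criterion = OFD(in_channels=32, out_channels=64)
loss = criterion(fm_s, fm_t)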