"); //-->
Author | PPRP
Source | GiantPandaCV
Editor | 极市平台
Overview
This article collects the distillation methods implemented in RepDistiller, explains the strategy behind each one as simply as possible, and provides the implementation source code.
1. KD: Knowledge Distillation
Full title: Distilling the Knowledge in a Neural Network
Link: https://arxiv.org/pdf/1503.02531.pdf
Published in: NIPS 2014
This is the classic work that explicitly introduced the concept of knowledge distillation. It softens the teacher network's logits with a temperature-scaled softmax and uses them as supervision for the student network, measuring the gap between student and teacher with the KL divergence. The overall pipeline is shown in the figure below (taken from Knowledge Distillation: A Survey).
For the student network, part of the supervision comes from the hard ground-truth labels, and the other part comes from the soft labels provided by the teacher network. Implementation:
import torch.nn as nn
import torch.nn.functional as F

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T  # temperature used to soften both distributions

    def forward(self, y_s, y_t):
        # Temperature-scaled log-probabilities of the student and probabilities of the teacher.
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        # KL divergence, scaled by T^2 to keep gradients comparable, averaged over the batch.
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss
The core is a single kl_div call that measures the discrepancy between the student's and the teacher's output distributions.
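In practice, this soft-label loss is combined with the usual cross-entropy loss on the hard labels, as described above. Below is a minimal training-step sketch: the networks `student` and `teacher`, the weighting factor `alpha`, and the temperature value are illustrative placeholders, not part of the original code.

import torch
import torch.nn as nn

# Hypothetical training step combining hard-label CE with the KD soft-label loss.
criterion_ce = nn.CrossEntropyLoss()   # supervision from hard labels
criterion_kd = DistillKL(T=4)          # supervision from the teacher's soft labels (T is illustrative)

def kd_train_step(student, teacher, images, labels, alpha=0.9):
    logits_s = student(images)
    with torch.no_grad():              # the teacher is frozen during distillation
        logits_t = teacher(images)
    loss_ce = criterion_ce(logits_s, labels)
    loss_kd = criterion_kd(logits_s, logits_t)
    return alpha * loss_kd + (1 - alpha) * loss_ce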
2. FitNet: Hints for thin deep nets
Full title: Fitnets: hints for thin deep nets
Link: https://arxiv.org/pdf/1412.6550.pdf
Published in: ICLR 2015 (poster)
This is the pioneering work on distilling intermediate layers. After mapping the student network's feature map to the same size as the teacher network's feature map, it uses the mean squared error (MSE) loss to measure the difference between the two. Implementation:
import torch.nn as nn

class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()  # plain MSE between student and teacher feature maps

    def forward(self, f_s, f_t):
        # f_s and f_t are assumed to already have matching shapes.
        loss = self.crit(f_s, f_t)
        return loss
The core of the implementation is simply MSELoss.
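Note that HintLoss itself assumes the student and teacher feature maps already match in shape; the hint training described above first maps the student's feature to the teacher's dimensions. The 1x1-conv regressor below is a minimal illustrative sketch of that step (not RepDistiller's exact regressor module), with placeholder shapes:

import torch
import torch.nn as nn

# Illustrative regressor projecting the student's hint feature map
# to the teacher's channel dimension before the MSE hint loss.
class HintRegressor(nn.Module):
    def __init__(self, s_channels, t_channels):
        super(HintRegressor, self).__init__()
        self.conv = nn.Conv2d(s_channels, t_channels, kernel_size=1)

    def forward(self, f_s):
        return self.conv(f_s)

# Usage sketch with random features (shapes are placeholders):
f_s = torch.randn(8, 64, 16, 16)    # student hint feature map
f_t = torch.randn(8, 128, 16, 16)   # teacher guided feature map
regressor = HintRegressor(64, 128)
loss = HintLoss()(regressor(f_s), f_t)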
3. AT: Attention Transfer
Full title: Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
Link: https://arxiv.org/pdf/1612.03928.pdf
Published in: ICLR 2017
To improve the student model, this work proposes transferring attention as the carrier of knowledge. The paper describes two kinds of attention: activation-based attention transfer and gradient-based attention transfer. Experiments show that the first kind is both simple and effective.
Implementation:
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of
    Convolutional Neural Networks via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p  # power used to build the activation-based attention map

    def forward(self, g_s, g_t):
        # g_s / g_t: lists of student / teacher feature maps from corresponding stages.
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def at_loss(self, f_s, f_t):
        # Match spatial sizes with adaptive average pooling before comparing attention maps.
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()

    def at(self, f):
        # Attention map: p-th power of activations averaged over channels, then L2-normalized.
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
Adaptive average pooling is first used to bring the two feature maps to the same spatial size, and then an MSE-style loss measures the gap between the two normalized attention maps.
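A usage sketch with random intermediate feature maps (all shapes are placeholders; in practice g_s and g_t are the lists of feature maps collected from corresponding stages of the student and teacher):

import torch

# Feature maps from two stages of the student and teacher (random placeholders).
g_s = [torch.randn(8, 32, 16, 16), torch.randn(8, 64, 8, 8)]
g_t = [torch.randn(8, 64, 16, 16), torch.randn(8, 128, 8, 8)]

at = Attention(p=2)
losses = at(g_s, g_t)     # one attention-transfer loss per stage
loss_at = sum(losses)     # typically summed (or weighted) into the total training loss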
4. SP: Similarity-Preserving
Full title: Similarity-Preserving Knowledge Distillation
Link: https://arxiv.org/pdf/1907.09682.pdf
Published in: ICCV 2019
SP belongs to the relation-based family of knowledge distillation methods. The idea is to transfer similarity-preserving knowledge, so that input pairs which produce similar activations in the teacher network also produce similar activations in the student network. The processing pipeline can be seen in the figure below: the corresponding feature maps of the teacher and the student are turned into bs×bs similarity matrices via inner products, and the mean squared error between the two similarity matrices is then measured.
The final loss is the mean squared error between the two normalized similarity matrices, averaged over the bs^2 entries:
L_SP = || G_s - G_t ||_F^2 / bs^2
where G denotes the bs×bs similarity matrix described above (row-normalized in the code below). Implementation:
import torch
import torch.nn as nn

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()

    def forward(self, g_s, g_t):
        # g_s / g_t: lists of student / teacher feature maps from corresponding layers.
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)

        # bs x bs similarity (Gram) matrices, row-normalized.
        G_s = torch.mm(f_s, torch.t(f_s))
        # G_s = G_s / G_s.norm(2)
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        # G_t = G_t / G_t.norm(2)
        G_t = torch.nn.functional.normalize(G_t)

        # Mean squared difference over the bs x bs entries.
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
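A usage sketch with random feature maps (placeholders). Note that the student and teacher shapes may differ, since each feature map is flattened per sample before the bs×bs Gram matrix is built:

import torch

g_s = [torch.randn(8, 32, 8, 8)]   # student feature maps (random placeholder)
g_t = [torch.randn(8, 64, 8, 8)]   # teacher feature maps (random placeholder)

sp = Similarity()
loss_sp = sum(sp(g_s, g_t))        # one similarity-preserving loss per selected layer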
5. CC: Correlation Congruence
Full title: Correlation Congruence for Knowledge Distillation
Link: https://arxiv.org/pdf/1904.01802.pdf
Published in: ICCV 2019
CC also belongs to the relation-based family of knowledge distillation methods. Distillation should not only guide the student to match the teacher on individual sample embeddings, but also transfer the correlation between pairs of samples; the correlation congruence term measures the Euclidean distance between the teacher's and the student's correlation matrices.
The overall loss adds this correlation-congruence term on top of the standard distillation loss. The implementation below is RepDistiller's reimplementation based on the paper:
import torch
import torch.nn as nn

class Correlation(nn.Module):
    """Similarity-preserving loss. My original reimplementation based on the paper
    before emailing the original authors."""
    def __init__(self):
        super(Correlation, self).__init__()

    def forward(self, f_s, f_t):
        return self.similarity_loss(f_s, f_t)

    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)

        # bs x bs correlation (Gram) matrices, each scaled by its Frobenius norm.
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = G_s / G_s.norm(2)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = G_t / G_t.norm(2)

        # Mean squared difference between the two correlation matrices.
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss
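Unlike the SP loss above, this reimplementation takes a single pair of features rather than lists of feature maps, and normalizes each Gram matrix by its Frobenius norm (via .norm(2)) instead of row-wise. A usage sketch with random embeddings (placeholders):

import torch

f_s = torch.randn(8, 128)   # student embeddings for a batch of 8 samples (placeholder)
f_t = torch.randn(8, 256)   # teacher embeddings for the same batch (placeholder)

cc = Correlation()
loss_cc = cc(f_s, f_t)      # gap between the student's and teacher's correlation matrices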