新闻  |   论坛  |   博客  |   在线研讨会
知识蒸馏综述:代码整理(1)
计算机视觉工坊 | 2022-01-16 19:47:18    阅读:1053   发布文章

作者 | PPRP 

来源 | GiantPandaCV

编辑 | 极市平台

导读

本文收集自RepDistiller中的蒸馏方法,尽可能简单解释蒸馏用到的策略,并提供了实现源码。

1. KD: Knowledge Distillation

全称:Distilling the Knowledge in a Neural Network

链接:https://arxiv.org/pdf/1503.02531.pd3f

发表:NIPS14

最经典的,也是明确提出知识蒸馏概念的工作,通过使用带温度的softmax函数来软化教师网络的逻辑层输出作为学生网络的监督信息,

使用KL divergence来衡量学生网络与教师网络的差异,具体流程如下图所示(来自Knowledge Distillation A Survey)

1.jpg

对学生网络来说,一部分监督信息来自hard label标签,另一部分来自教师网络提供的soft label。代码实现:

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T
    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss

核心就是一个kl_div函数,用于计算学生网络和教师网络的分布差异。

2. FitNet: Hints for thin deep nets

全称:Fitnets: hints for thin deep nets

链接:https://arxiv.org/pdf/1412.6550.pdf

发表:ICLR 15 Poster

对中间层进行蒸馏的开山之作,通过将学生网络的feature map扩展到与教师网络的feature map相同尺寸以后,使用均方误差MSE Loss来衡量两者差异。

2.jpg

实现如下:

class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()
    def forward(self, f_s, f_t):
        loss = self.crit(f_s, f_t)
        return loss

实现核心就是MSELoss。

3. AT: Attention Transfer

全称:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

链接:https://arxiv.org/pdf/1612.03928.pdf

发表:ICLR16

为了提升学生模型性能提出使用注意力作为知识载体进行迁移,文中提到了两种注意力,一种是activation-based attention transfer,另一种是gradient-based attention transfer。实验发现第一种方法既简单效果又好。

3.jpg

实现如下:

class Attention(nn.Module):
    """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
    via Attention Transfer
    code: https://github.com/szagoruyko/attention-transfer"""
    def __init__(self, p=2):
        super(Attention, self).__init__()
        self.p = p
    def forward(self, g_s, g_t):
        return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def at_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()
    def at(self, f):
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

首先使用avgpool将尺寸调整一致,然后使用MSE Loss来衡量两者差距。

4. SP: Similarity-Preserving

全称:Similarity-Preserving Knowledge Distillation

链接:https://arxiv.org/pdf/1907.09682.pdf

发表:ICCV19SP

归属于基于关系的知识蒸馏方法。文章思想是提出相似性保留的知识,使得教师网络和学生网络会对相同的样本产生相似的激活。可以从下图看出处理流程,教师网络和学生网络对应feature map通过计算内积,得到bsxbs的相似度矩阵,然后使用均方误差来衡量两个相似度矩阵。

4.jpg

最终Loss为:

G代表的就是bsxbs的矩阵。实现如下:

class Similarity(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(Similarity, self).__init__()
    def forward(self, g_s, g_t):
        return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        G_s = torch.mm(f_s, torch.t(f_s))
        # G_s = G_s / G_s.norm(2)
        G_s = torch.nn.functional.normalize(G_s)
        G_t = torch.mm(f_t, torch.t(f_t))
        # G_t = G_t / G_t.norm(2)
        G_t = torch.nn.functional.normalize(G_t)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss

5. CC: Correlation Congruence

全称:Correlation Congruence for Knowledge Distillation

链接:https://arxiv.org/pdf/1904.01802.pdf

发表:ICCV19

CC也归属于基于关系的知识蒸馏方法。不应该仅仅引导教师网络和学生网络单个样本向量之间的差异,还应该学习两个样本之间的相关性,而这个相关性使用的是Correlation Congruence 教师网络雨学生网络相关性之间的欧氏距离。

整体Loss如下:

实现如下:

class Correlation(nn.Module):
    """Similarity-preserving loss. My origianl own reimplementation 
    based on the paper before emailing the original authors."""
    def __init__(self):
        super(Correlation, self).__init__()
    def forward(self, f_s, f_t):
        return self.similarity_loss(f_s, f_t)
    def similarity_loss(self, f_s, f_t):
        bsz = f_s.shape[0]
        f_s = f_s.view(bsz, -1)
        f_t = f_t.view(bsz, -1)
        G_s = torch.mm(f_s, torch.t(f_s))
        G_s = G_s / G_s.norm(2)
        G_t = torch.mm(f_t, torch.t(f_t))
        G_t = G_t / G_t.norm(2)
        G_diff = G_t - G_s
        loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
        return loss


*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。

参与讨论
登录后参与讨论
推荐文章
最近访客