"); //-->
Retentive 网络
RetNet 由 L 个相同的块堆叠而成,其布局与 Transformer 类似(即残差连接和 pre-LayerNorm)。每个 RetNet 块包含两个模块:多尺度retention(MSR)和前馈网络(FFN)。
给定输入序列,RetNet 以自回归方式对序列进行编码。输入向量
首先被封装为
,其中
是隐藏维度。然后,计算上下文向量表征
。
Retention
RetNet 具有循环和并行双重形式的 retention 机制,因此能够并行地训练模型,同时循环地进行推理。
给定输入,将其投影为一维函数 v (n) = X_n - w_V。考虑一个序列建模问题,通过状态 s_n 映射 v (n) → o (n)。
为简单起见,让 v_n, o_n 表示 v (n),o (n)。此处以循环的方式对映射进行表述:
其中,将 v_n 映射到状态向量 s_n,然后实现线性变换,对序列信息进行循环编码。
接下来,使投影 Q_n, K_n 具有内容感知能力:
其中是可学习矩阵。
将矩阵对角化,其中
。然后得到
。通过将 Λ 吸收到 W_Q 和 W_K 中,可以将方程(1)重写为
其中,称为 xPos,即为 Transformer 提出的相对位置嵌入。进一步将 γ 简化为标量,公式(3)则变为
其中†为共轭转置。该公式很容易在训练实例中并行化。
总之,从公式 (1) 所示的循环建模开始,然后推导出公式 (4) 中的并行公式。将原始映射 v (n) →o (n) 视为向量,得到如下的 retention 机制:
1)Retention 的并行表征
如图 3a 所示,Retention 层定义为:
与自注意力类似,并行表征使得能够使用 GPU 高效地训练模型。
2)Retention 的循环表征
如图 3b 所示,所提出机制也可以写成循环神经网络(RNN),这有利于推理。对于第 n 个时间步,循环得到的输出为
这里的 Q, K, V, γ 和公式 5 相同。
3)Retention 分块循环表征
并行表征和循环表征的混合形式可以加速训练,特别是对于长序列。此处将输入序列划分为若干小块。在每个块内,按照并行表征(公式(5))进行计算。相反,跨块信息则按照循环表征(公式(6))进行传递。具体来说,让 B 表示块长度。通过以下方式计算第 i 个分块的 retention 输出:
其中 [i] 表示第 i 个数据块,例如。
门控多尺度 Retention
在每个层中,研究者使用 h = d_model/d 个 retention 头,其中 d 是头的维度。这些头使用不同的参数矩阵 W_Q、W_K、W_V ∈ R^(d×d)。此外,多尺度 retention(MSR)为每个头分配不同的 γ。为了简化,研究者将 γ 设置为在不同层之间相同并保持固定。另外,他们添加了一个 swish 门 [RZL17] 来增加层的非线性性。形式上,给定输入 X,研究者将该层定义为:
其中,为可学习参数,GroupNorm [WH18] 对每个头的输出进行归一化,遵循 [SPP^+19] 中提出的 SubLN。注意,这些头使用多个 γ 尺度,这会带来不同的方差统计结果。所以研究者分别对头的输出进行归一化。
retention 的伪代码如图 4 所示。
Retention Score 归一化
研究者利用 GroupNorm 的尺度不变性来提高 retention 层的数值精度。具体而言,在 GroupNorm 中乘以一个标量值不会影响输出和反向梯度,即 GroupNorm (α ∗ head_i) = GroupNorm (head_i)。研究者在公式(5)中实现了三个归一化因子。首先,他们将 QK^⊺ 归一化为 QK^⊺ / √ d。其次,他们将 D 替换为。第三,他们用 R 表示 retention scores R = QK^⊺ ⊙ D,将其归一化为
。然后,retention 输出变为
。由于尺度不变的特性,上述技巧不会影响最终的结果,同时稳定了正向和反向传递的数值流动。
Retention 网络总体结构
对于一个 L 层的 retention 网络,研究者堆叠多尺度 retention (MSR) 和前馈网络(FFN)来构建模型。形式上,输入序列通过一个词嵌入层被转换为向量。研究者使用打包后的嵌入
作为输入,并计算模型的输出 X^L:
其中,LN (・) 为 LayerNorm [BKH16]。FFN 部分计算为 FFN (X) = gelu (XW_1) W_2,其中 W_1、W_2 为参数矩阵。
训练:研究者在训练过程中使用了并行(公式 5)表示和块循环(公式 7)表示。序列或块内的并行有效地利用了 GPU 来加速计算。更有利的是,块循环对于长序列训练特别有用,这在 FLOPs 和内存消耗方面都是有效的。
推理:在推理过程中,研究者采用了循环表示(公式 6),这非常适合自回归解码。O (1) 的复杂度减少了内存占用和推理延迟,同时实现了相当的结果。
*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。