Post

Tiny-LLM(二):实现旋转位置编码RoPE

Tiny-LLM(二):实现旋转位置编码RoPE

1 任务总览

在Transformer架构中,由于自注意力机制本身不包含位置信息,我们需要通过位置编码来为模型提供序列中token的顺序信息。本章将实现现代大语言模型广泛采用的旋转位置编码(RoPE)

任务细节Tiny-LLM Week1-Day2

1.1 实现传统 RoPE

在第一步中,我们需要在 src/tiny_llm/positional_encoding.py 中实现标准版本的 RoPE。

每个注意力头的维度 D 会被看作由成对的元素组成(例如 [x0, x1][x2, x3] 等)。每一对被当作一个复数 (x_even, x_odd),然后与对应的 cossin 频率进行旋转变换:

1
2
x'[0] = x[0] * cos - x[1] * sin
x'[1] = x[0] * sin + x[1] * cos

这样,模型就能通过旋转角度体现序列中的位置信息。

同时,我们还需要理解 offset 的概念。如果没有偏移量(offset),我们就认为序列从第 0 个 token 开始编码;如果有偏移量,比如 5..10,那就代表这个序列实际上是原始序列中第 5~9 个位置的片段,因此需要从第 5 个频率开始应用位置编码。

1.2 实现 Qwen2 风格的非传统 RoPE

在 Qwen2 模型中,RoPE 的形式略有不同。它不是成对旋转,而是将头部维度拆分为前半部分和后半部分,然后分别使用不同频率的旋转。也就是说,我们把输入向量 x 拆成 x1(前半部分)和 x2(后半部分),再执行:

1
2
output[:HALF] = x1 * cos - x2 * sin
output[HALF:] = x1 * sin + x2 * cos

这种方式可以更好地兼容多头注意力下的特征分布,让模型在保持性能的同时拥有更高的灵活性。

2 背景知识

在深入实现旋转位置编码(Rotary Position Embedding,RoPE)之前,我们需要先理解它是如何从传统位置编码演化而来的,以及它为何能更好地捕获相对位置信息并提升大模型的外推性(extrapolation)

2.1 为什么需要位置编码?

Transformer 模型不像 RNN 那样会按顺序处理序列,因此模型本身无法“知道”一个 token 在句子中的位置。为了让模型理解序列的顺序,我们需要给每个 token 注入“位置感”,也就是位置编码(Positional Encoding)

简单来说,词嵌入(embedding)告诉模型“是什么词”,而位置编码告诉模型“这个词在句子中在哪里”。

在经过 embedding 层后,我们得到每个 token 的词向量:xᵢ ∈ ℝᵈ 表示第 i 个 token 的 d 维向量。在计算 self-attention 之前,我们会把位置信息注入到 Q、K、V 向量中:

tinyllm-1

这里 m/n 表示位置信息。所有位置编码的设计,核心目标都是:设计一个合理的函数 f,让模型既能捕获顺序信息,又能泛化到未见过的序列长度。

2.2 传统绝对位置编码

tinyllm-2

最早的 Transformer 使用了正弦-余弦函数(Sinusoidal)生成固定的位置编码。其思路是用不同频率的 sin 和 cos 函数为不同维度的向量编码,让每个维度对应一种周期变化规律。

通俗地讲,每个位置 i 都会生成一组独特的波形模式,模型通过这些不同周期的 sin/cos 波动,就能区分“第1个词”和“第100个词”的区别。不过,这种方式存在两个问题:

  1. 它只能表达绝对位置(即“我是谁”),而不能表达相对距离(即“我和你差几位”)。
  2. 当我们希望模型在推理阶段处理比训练阶段更长的序列时,sin/cos 的周期性会让模型的泛化能力迅速下降——这就是所谓的外推性差(poor extrapolation)

外推性是大模型中一个非常关键的能力。举个例子:如果一个模型在训练时只见过长度为 512 的文本,那么在推理时输入 2048 个 token,它可能会“迷路”,因为它从未见过那么长的上下文。

RoPE 就是为了解决这个问题:它能让模型学到位置之间的相对关系,而不是固定的绝对编号。因此,哪怕你给它更长的文本,它仍能根据“相对距离”来计算注意力,而不是依赖“第几号位置”。

2.3 旋转位置编码

论文 RoFormer: Enhanced Transformer with Rotary Position Embedding 提出了 RoPE。它的核心思想是:**用旋转操作代替位置相加操作。找到一个位置编码方式,使得 query 向量 $q_m$ 和 key 向量 $k_n$ 之间的内积能够自然地包含它们之间的相对位置信息$(m-n)$。

假设词嵌入维度为2维 $d=2$,利用二维平面上的几何性质,RoPE提出,对于query向量 $q_m$:

tinyllm-3

这实际上是一个旋转矩阵,角度 θᵢ 随位置 i 变化。这样,RoPE 把每个 token 的位置编码成一个“旋转角度”,让相对位置信息通过旋转角度差自然地体现在注意力计算中。

总的来说,RoPE 的精髓在于:

  1. 不再将位置编码加到 embedding 上,而是通过旋转操作融入;
  2. 将“绝对位置”映射为“相对角度差”;
  3. 让注意力计算对序列长度具有更好的外推性。

3 代码实现

RoPE 的核心是:把每一对维度 (x1, x2) 看作复数的实部与虚部,然后让它按角度 θ 做旋转:

1
2
[real]   [cosθ  -sinθ] [x1]
[imag] = [sinθ   cosθ] [x2]

这里 $θ = pos × ω_i$,其中 $ω_i = base^{-2i/d}$,代表不同维度对应不同的旋转频率。

主要实现步骤:

  1. 构造频率矩阵:在初始化函数 __init__ 中,首先确定维度和频率,预先计算好 cosθsinθ

  2. 拆分:RoPE 需要把最后一维分成两部分(相当于复数的实部和虚部),有两种模式

    • traditional 模式:交错配对 [x0,x1], [x2,x3], ...

    • split 模式:前半 [0:D/2) 为实部,后半 [D/2:D) 为虚部

  3. 旋转:把预计算好的 cos/sin reshape 成可广播的形状,执行复数旋转
  4. 结果重组:把 (real, imag) 拼回原维度,返回的 y 与输入形状相同,但每个元素都已带有位置编码。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class RoPE:
    def __init__(
        self,
        dims: int,
        seq_len: int,
        base: int = 10000,
        traditional: bool = False,
    ):
        half_dims = dims // 2
        self.half_dims = half_dims
        self.traditional = traditional
        pos = mx.arange(seq_len)    # pos = [0, 1, 2, ..., L-1]
        inner = mx.arange(0, half_dims, dtype=mx.float32) / half_dims   # [2i/d], i=0,...,half_dims-1
        w =  mx.power(base, -inner) # [1/10000^(i/half_dims)], i=0,...,half_dims-1
        theta = mx.outer(pos, w)    # (seq_len, half_dims)
        self.cos_theta = mx.cos(theta)
        self.sin_theta = mx.sin(theta)

    def __call__(
        self, x: mx.array, offset: list[slice] | slice | None = None
    ) -> mx.array:
        N, S, H, D = x.shape

        # apply offset
        if offset is not None:
            if isinstance(offset, slice):
                assert offset.stop - offset.start == S, f"offset must be of length {S}"
            elif isinstance(offset, list):
                assert len(offset) == N, (
                    f"offsets must have the same length as batch size {N}"
                )
                for o in offset:
                    assert o.stop - o.start == S, f"offset must be of length {S}"
                offset = mx.array([list(range(i.start, i.stop)) for i in offset])
        cos_biasis = (self.cos_theta[:S, :] if offset is None else self.cos_theta[offset, :])
        sin_biasis = (self.sin_theta[:S, :] if offset is None else self.sin_theta[offset, :])

        # reshape x: (N, S, H, D // 2, 2)
        if self.traditional:    # [0, 2, 4, 6] [1, 3, 5, 7] format
            x = x.reshape(N, S, H, self.half_dims, 2)
            x1 = x[..., 0]
            x2 = x[..., 1]
        else:                   # [0, 1, 2, 3] [4, 5, 6, 7] format
            x1 = x[..., 0 : self.half_dims]
            x2 = x[..., self.half_dims : D]
        
        # reshape basis: (N, S, 1, half_dims)
        cos_biasis = cos_biasis.reshape(-1, S, 1, self.half_dims)
        sin_biasis = sin_biasis.reshape(-1, S, 1, self.half_dims)

        # [real; imag] = [cos -sin; sin cos] * [x1; x2]
        real = mx.multiply(x1, cos_biasis) - mx.multiply(x2, sin_biasis)
        imag = mx.multiply(x2, cos_biasis) + mx.multiply(x1, sin_biasis)
        if self.traditional:
            y = mx.stack([real, imag], axis=-1)
            y = y.reshape(N, S, H, D)
        else:
            y = mx.concat([real, imag], axis=-1)
            y = y.reshape(N, S, H, D)
        return y.astype(x.dtype)

4 参考资料

This post is licensed under CC BY 4.0 by the author.