最近在看RoPE相关内容,一些方法通过简单修改位置编码就可以无需训练支持更长的文本内容。由于一些模型,已经训练好了,但是怎么修改已经训练好的模型位置编码。查了以下相关代码,记录一下。原理这里就不细讲了,贴几个相关博客。
十分钟读懂旋转编码(RoPE)
Transformer升级之路:11、将β进制位置进行到底
Transformer升级之路:10、RoPE是一种β进制编码
NTK
下图为NTK的原理证明:截取自Transformer升级之路:10、RoPE是一种β进制编码
看了上面的公式,我在考虑为什么需要建立 λ \lambda λ和k之间的关系?
因为我们要修改
β
\beta
β进制为
β
λ
\beta\lambda
βλ,由于k我们是可以知道的比如我们需要把位置编码缩小为10倍,直接设置k为10,但是采用NTK的方式,维度缩小为10倍,那么我们就不确定,
λ
\lambda
λ怎么设置了。所以需要简历
λ
\lambda
λ和k之间的关系,从上图可知,
λ
=
k
2
/
(
d
−
2
)
\lambda=k^{2/(d-2)}
λ=k2/(d−2)。
下面开始理解如何修改RoPE为NTK的形式:
以下为LLama的RoPE代码实现
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._sin_cached
@property
def cos_cached(self):
logger.warning_once(
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._cos_cached
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
λ
\lambda
λ和k之间的关系,那么代码怎么实现呢,我们只需要修改
β
λ
\beta\lambda
βλ的结果即可,其中
β
\beta
β为
1000
0
2
/
d
10000^{2/d}
100002/d。
参考代码为:点击
import transformers
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
#The method is just these three lines
max_position_embeddings = 16384
k = 8 #Alpha value
base = base * k ** (dim / (dim-2)) #Base change formula
old_init(self, dim, max_position_embeddings, base, device)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
为什么采用base = base * k ** (dim / (dim-2)),原始的base为10000, 而 β \beta β为 1000 0 2 / d 10000^{2/d} 100002/d,发现在代码里面修改的仅仅是base的结果, β \beta β= b a s e 2 / d base^{2/d} base2/d,而 λ = k 2 / ( d − 2 ) \lambda=k^{2/(d-2)} λ=k2/(d−2),我们需要把k和base进行融合,修改成,base*k的形式形成新的base, λ \lambda λ等于k的指数 2 / ( d − 2 ) ∗ d / 2 ∗ 2 / d = d / ( d − 2 ) ∗ 2 / d 2/(d-2)*d/2*2/d=d/(d-2)*2/d 2/(d−2)∗d/2∗2/d=d/(d−2)∗2/d, λ = ( k d / ( d − 2 ) ) 2 / d \lambda=(k^{d/(d-2)})^{2/d} λ=(kd/(d−2))2/d,因为2/d在RoPE的代码里面已经计算过了:
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
所以我们则赋值新的base为base * k ** (dim / (dim-2))。
Dynamic NTK
Dynamic在NTK的基础上进行简单的修改,采用NTK的时候更加灵活。
截图源于:RoPE到底是何方神圣(数学推理+优化方法)
代码实现:
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def forward(self, x, position_ids):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_position_embeddings:#只有长度超过了预训练的阈值,进行NTK
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
cos, sin = super().forward(x, position_ids)
return cos, sin