注意力计算
下面的代码节选自site-packages/transformers/modeling_gpt2.py
的Attention
类, 版本: transformers==2.1.1
batch_size
:B
seq_len
:T
hidden_dim
:D
self.num_head
:H
初始化¶
def __init__(self, nx, n_ctx, config, scale=False):
super(Attention, self).__init__()
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = set()
nx
: 输入嵌入的维度, 在GPT2中, 这个值是768n_ctx
: 上下文最大长度, 在GPT2中, 这个值是1024, 它表示的是一次前向传播中模型的输入+输出的最大token数量config
: 一个配置对象, 里面存放了一些超参数, 如n_head
,attn_pdrop
,redis_pdrop
,output_attn
等scale
: 决定在注意力计算的时候是否执行1/sqrt(d_k)
的步骤n_state
: 就是nx
assert n_state % config.n_head
: 确保n_state
可以被n_head
整除, 在多头注意力中会把embedding的维度平分到各个头上self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
: 生成一个下三角矩阵, 用于因果掩码. 注册的是一个常量张量, 不会被当作可训练参数self.n_head = config.n_head
: 多头注意力的头数, 在GPT2中, 这个值是12self.split_size = n_state
: 就是nx
, 可读性, 兼容性self.scale
: 表示后续在_attn
函数中是否使用1/sqrt(d_k)
做缩放self.c_attn = Conv1D(n_state * 3, nx)
: 一个Conv1D层, 将输入映射到Q, K, V三个矩阵, 总特征维度是n_state*3
self.c_proj = Conv1D(n_state, nx)
: 一个Conv1D层, 将注意力计算结果(value)矩阵做一个线性变换self.attn_dropout = nn.Dropout(config.attn_pdrop)
: 用于注意力权重的Dropoutself.resid_dropout = nn.Dropout(config.resid_pdrop)
: 用于残差连接的Dropoutself.pruned_heads = set()
: 记录剪枝被剪掉的注意力头
前向传播¶
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
if layer_past is not None:
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
key = torch.cat((past_key, key), dim=-1)
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
a = attn_outputs[0]
a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a)
outputs = [a, present] + attn_outputs[1:]
return outputs # a, present, (attentions)
layer_past
: 过去时间步的K, V矩阵的缓存, 可以加速推理head_mask
: 用来对注意力头进行选择性地屏蔽x = self.c_attn(x)
: 将输入映射到3个矩阵, 形状从(B, T, D)变成(B, T, D*3)query, key, value = x.split(self.split_size, dim=2)
: 将上一步的拼接结果拆分成Q, K, V三个部分, 每个部分的形状都是(B, T, D), 会在维度2就是最后一个维度上均分三份.query = self.split_heads(query)
: 把Q拆分成多个注意力头, 形状从(B, T, D)变为(B, H, T, D//H)key = self.split_heads(key, k=True)
: 将K拆分为多个注意力头, 形状从(B, T, D)变为(B, H, D//H, T), 注意, 这里k=True
, 这是因为要对矩阵进行转置, 使其能执行torch.matmul(q, k)value = self.split_heads(value)
: 把V拆分成多个注意力头, 形状从(B, T, D)变为(B, H, T, D//H)-
if layer_past is not None:
: 确实是否开启KV CacheKV Cache
我们在进行自回归输出的时候, 这个T的值是在不断的变大的, 一开始的时候是提示词的token长度, 随着自回归, 这个值会变得非常大, 例如, 目前输入+输出的长度已经达到了100万token. 那么我们的Q, K, V矩阵会变得非常长, 宽还是嵌入维度(如768).
torch.matmul(q, k)
产生的方阵会非常非常大, 100万*100万. 这就是为啥你在比较弱一点的硬件上跑的时候, 随着输出长度越来越大或者对话轮数越来越大, 蹦字的速度越来越慢的原因, GPT-3.5的n_ctx
是4096, 所以随着你和它聊天轮数的增加, 超出了窗口, 它会对前面的内容进行截取, 导致前面的上下文丢失. 举个例子, 你在叫GPT翻译, 你把"请翻译文本"放在prompt的最前面, 然后粘贴了一个1000字的论文, 然后你聊天聊了几轮之后, 你会发现GPT-3.5似乎在和你对话了, 而不是翻译文本, 这是因为超出了它的窗口, 它看不到一开始的指令tokens了. 这可以通过KV Cache解决, 简单的来说, 我们可以搞一个超级超级大的窗口, 然后用KV Cache实现高效推理.- 如果没有KV Cache, 那么这个方阵的下三角区域全部都要重新计算, 非常消耗计算资源
- 如果有KV Cache, 那么这个方阵的下三角区域只有最后一行要重新计算, 即只有当前的新token的Q, K, V是变的, 而之前所有tokens的K, V都是缓存的. 为什么Q不缓存呢? 我们想要得到这一行, 需要知道前面所有tokens的K, torch.matmul(Q, K)得到权重向量, 这个权重向量代表了前面所有tokens的权重, 和它们的V相乘把信息汇总到当前的这个token的V里面.
KV Cache的Trade OFF是内存的消耗增加了, 但是QK矩阵乘法的效率增加了.
-
key = torch.cat((past_key, key), dim=-1)
: 将之前的key和新key在最后一维(dim = -1)进行拼接, 注意了之前Key的维度是(B, H, D//H, T), 拼接之后的维度是(B, H, D//H, T+1) value = torch.cat((past_value, value), dim=-2)
: 将之前的value和新value在倒数第二维(dim = -2)进行拼接, 注意了之前Value的维度是(B, H, T, D//H), 拼接之后的维度是(B, H, T+1, D//H)present = torch.stack((key.transpose(-2, -1), value))
: 每个transformer block都会执行一个forward
函数, 所以在一次前向传播中past_key
和past_value
这两个量都是不会变的, 但是新token的key
,value
,query
在不停改变, 所以present
表示的是当前block的KV Cache+新token的KVattn_outputs = self._attn(query, key, value, attention_mask, head_mask)
: 执行注意力机制的核心计算过程, 返回的是经过更新后的V, 形状为(B, H, T+1, D//H)a = attn_outputs[0]
: 取第一个元素, attn_outputs的输出第一个元素是更新后的V, 第二个元素是注意力权重, 用于可视化a = self.merge_heads(a)
: 将注意力头合并, 形状从(B, H, T+1, D//H)变为(B, T+1, D), 这里的merge_heads
函数是split_heads
函数的逆操作a = self.c_proj(a)
: 将注意力计算结果做线性变换, 形状保持不变, 还是(B, T+1, D), 这里的c_proj
是一个Conv1D层a = self.resid_dropout(a)
: 对注意力计算结果施加dropout, 减少模型过拟合, 让注意力分布更具有随机性outputs = [a, present] + attn_outputs[1:]
: 返回的结果是一个列表, 包括更新后的Value矩阵, present矩阵和注意力权重(如有)
注意力头生成¶
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
x
: 一个批次的嵌入, 形状为(B, T, D)
, 如(2, 200, 768)
k
: 决定如何重新排列维度, 对K矩阵要执行转置(见forward
函数)new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
: 计算新的形状, 原来的形状是(B, T, D)
, 新的形状是(B, H, T, D//H)
, 例如, 假设self.n_head=12
,(2, 200, 768)->(2, 12, 200, 64)
x = x.view(*new_x_shape)
: 执行形状变换
注意力头合并¶
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
x
: 更新之后的V, 形状为(B, H, T+1, D//H)x = x.permute(0, 2, 1, 3).contiguous()
: 将维度重新排列, 变为(B, T+1, H, D//H), 注意这里的contiguous()
是为了保证内存连续性, 这样可以避免后续的view操作出错new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
: 计算新的形状, 原来的形状是(B, T+1, H, D//H), 新的形状是(B, T+1, D), 例如, 假设self.n_head=12
,(2, 201, 12, 64)->(2, 201, 768)
return x.view(*new_x_shape)
: 执行形状变换, 这里的view
操作会把最后两个维度合并成一个维度, 这样就得到了更新后的V矩阵, 形状为(B, T+1, D)
注意力权重矩阵计算¶
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
nd, ns = w.size(-2), w.size(-1)
b = self.bias[:, :, ns-nd:ns, :ns]
w = w * b - 1e4 * (1 - b)
if attention_mask is not None:
w = w + attention_mask
w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
if head_mask is not None:
w = w * head_mask
outputs = [torch.matmul(w, v)]
if self.output_attentions:
outputs.append(w)
return outputs
w = torch.matmul(q, k)
: 计算Q*K^T
, 得到原始的注意力分数矩阵w
if self.scale: w = w / math.sqrt(v.size(-1))
: 使用1/sqrt(d_k)
进行缩放, 避免向量维度过大造成的内积值过大nd, ns = w.size(-2), w.size(-1)
: 取得注意力分数矩阵的形状b = self.bias[:, :, ns-nd:ns, :ns]
: 掩码矩阵, 常常是用一个上三角或者下三角矩阵表示w = w * b - 1e4 * (1 - b)
: 把被遮住的注意力分数直接变成一个很大的负值, 这样在softmax的时候, 该位置的分数会被极大值抑制if attention_mask is not None: w = w + attention_mask
: 如果外部额外传入了attention_mask
, 则把这个mask加到注意力分数里w = nn.Softmax(dim=-1)(w)
: 在最后的一个维度上做softmax, 将注意力分数转为注意力权重self.attn_dropout(w)
: 对注意力权重施加dropout, 减少模型过拟合, 让注意力分布更具有随机性if head_mask is not None: w = w * head_mask
: 如果有head_mask
, 说明在训练或者推理的时候要屏蔽掉某些注意力头, 会在对应注意力权重上乘以0
(或其他权重)outputs = [torch.matmul(w, v)]
: 最后用w
和v
做矩阵乘法, 得到加权后的value输出if self.output_attentions: outputs.append(w)
: 如果self.output_attentions
为真, 那么会在outputs
中把注意力权重w
一并返回, 方便后续做可视化或其他分析
注意力头剪枝¶
下面的这个函数用于从模型中移除指定的注意力头, 这是一种常见的模型压缩技术, 可以减小模型的大小, 加快推理速度, 同时尽量保持模型的性能.
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)
heads
: 需要剪枝的注意力头的列表, 例如[0, 1, 2]
, 表示要剪掉第0, 1, 2个注意力头if len(heads) == 0: return
: 如果没有要剪枝的头, 直接返回mask = torch.ones(self.n_head, self.split_size // self.n_head)
: 创建一个掩码矩阵, 形状为(H, D//H)
, 用于标记哪些注意力头是保留的, 哪些是剪掉的, 初始为全1, 表示所有头都保留.heads = set(heads) - self.pruned_heads
: 将要剪掉的头转换为集合, 并去除已经剪掉的头, 以避免重复剪枝head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
: 计算在当前头索引之前有多少头已经被剪枝, 并调整索引值(因为之前的剪枝操作会影响后面的索引)-
mask[head]=0
: 将当前剪掉的头在掩码矩阵中标记为0, 表示该头被剪掉调整索引
假设我们的Transformer模型最初有8个注意力头, 编号为0-7.
-
初始状态
- 所有头都存在: [0, 1, 2, 3, 4, 5, 6, 7]
- 已剪枝头集合:
self.pruned_heads = {}
-
第一次剪枝操作
假设我们要剪掉头 [2, 5]:
处理头2:
- 计算2之前有多少个已被剪枝的头:
sum(1 if h < 2 else 0 for h in self.pruned_heads) = 0
- 调整后的索引:
head = 2 - 0 = 2
- 设置
mask[2] = 0
(标记为需要剪枝)
处理头5:
- 计算5之前有多少个已被剪枝的头:
sum(1 if h < 5 else 0 for h in self.pruned_heads) = 0
- 调整后的索引:
head = 5 - 0 = 5
- 设置
mask[5] = 0
剪枝后:
- 剩余头: [0, 1, 3, 4, 6, 7]
- 已剪枝头集合更新为:
self.pruned_heads = {2, 5}
- 计算2之前有多少个已被剪枝的头:
-
第二次剪枝操作
现在, 假设我们要剪掉原始编号为 [3, 6] 的头:
处理头3:
- 关键点: 计算3之前有多少个已被剪枝的头:
sum(1 if h < 3 else 0 for h in self.pruned_heads) = 1
(只有头2小于3) - 调整后的索引:
head = 3 - 1 = 2
- 设置
mask[2] = 0
(注意: 这里的索引2对应的是当前数组中的第三个元素, 也就是原始头3)
处理头6:
- 关键点: 计算6之前有多少个已被剪枝的头:
sum(1 if h < 6 else 0 for h in self.pruned_heads) = 2
(头2和头5都小于6) - 调整后的索引:
head = 6 - 2 = 4
- 设置
mask[4] = 0
(注意: 这里的索引4对应的是当前数组中的第五个元素, 也就是原始头6)
剪枝后:
- 剩余头: [0, 1, 4, 7]
- 已剪枝头集合更新为:
self.pruned_heads = {2, 3, 5, 6}
- 关键点: 计算3之前有多少个已被剪枝的头:
-
-
后续代码
剪枝
假设我们有一个Transformer模型, 其配置如下:
- 4个注意力头(
n_head = 4
) - 每个头的维度为16(
split_size // n_head = 16
) - 总特征维度为64(
split_size = 64
) - 我们要剪枝的头是 [1, 3]
-
创建并应用掩码
经过前面的代码处理后, 我们得到了一个掩码矩阵:
-
处理掩码和创建索引
这一行将二维掩码展平成一维并转换为布尔值:
- 展平后: [1,1,...,1, 0,0,...,0, 1,1,...,1, 0,0,...,0]
- 转换为布尔值后: [True,True,...,True, False,False,...,False, True,True,...,True, False,False,...,False]
这一行创建保留位置的索引:
- 首先生成 [0,1,2,...,63]
- 然后只保留掩码为True的位置
- 结果是: [0,1,...,15, 32,33,...,47]
-
创建注意力矩阵的索引
在Transformer中, QKV(查询, 键, 值)通常合并在一个矩阵中:
- 前64个位置对应Q
- 中间64个位置对应K
- 后64个位置对应V
因此这行代码构建了完整的索引:
- index: [0,1,...,15, 32,33,...,47] (Q部分)
- index + 64: [64,65,...,79, 96,97,...,111] (K部分)
- index + 128: [128,129,...,143, 160,161,...,175] (V部分)
- 合并后: [0,1,...,15, 32,33,...,47, 64,65,...,79, 96,97,...,111, 128,129,...,143, 160,161,...,175]
-
实际执行剪枝操作
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) # 列方向索引 self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) # 行方向索引
若没有剪枝:
- c_attn输入形状: [B, T, 64]
- c_proj输入形状: [B, T+1, 64]
- 经过c_attn的输出形状: [B, T, 3×64] (计算QKV的层)
- 经过c_proj的输出形状: [B, T+1, 64] (注意力输出投影层)
若有剪枝(原本是4个头, 现在只要2个头, 但是每个头的维度还是16):
- c_attn输入形状: [B, T, 64]
- c_proj输入形状: [B, T+1, 32]
- 经过c_attn的输出形状: [B, T, 3×32] (只保留头0和头2的列)
- 经过c_proj的输出形状: [B, T+1, 32] (只保留头0和头2的行)
-
更新模型参数
新的特征维度 = (64 ÷ 4) × (4 - 2) = 16 × 2 = 32
新的头数量 = 4 - 2 = 2
更新已剪枝头的集合: self.pruned_heads = self.pruned_heads ∪ {1, 3}
- 4个注意力头(