跳转至

注意力计算

下面的代码节选自site-packages/transformers/modeling_gpt2.pyAttention类, 版本: transformers==2.1.1

  • batch_size: B
  • seq_len: T
  • hidden_dim: D
  • self.num_head: H

初始化

def __init__(self, nx, n_ctx, config, scale=False):
    super(Attention, self).__init__()
    self.output_attentions = config.output_attentions

    n_state = nx  # in Attention: n_state=768 (nx=n_embd)
    # [switch nx => n_state from Block to Attention to keep identical to TF implem]
    assert n_state % config.n_head == 0
    self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
    self.n_head = config.n_head
    self.split_size = n_state
    self.scale = scale

    self.c_attn = Conv1D(n_state * 3, nx)
    self.c_proj = Conv1D(n_state, nx)
    self.attn_dropout = nn.Dropout(config.attn_pdrop)
    self.resid_dropout = nn.Dropout(config.resid_pdrop)
    self.pruned_heads = set()
  • nx: 输入嵌入的维度, 在GPT2中, 这个值是768
  • n_ctx: 上下文最大长度, 在GPT2中, 这个值是1024, 它表示的是一次前向传播中模型的输入+输出的最大token数量
  • config: 一个配置对象, 里面存放了一些超参数, 如n_head, attn_pdrop, redis_pdrop, output_attn
  • scale: 决定在注意力计算的时候是否执行1/sqrt(d_k)的步骤
  • n_state: 就是nx
  • assert n_state % config.n_head: 确保n_state可以被n_head整除, 在多头注意力中会把embedding的维度平分到各个头上
  • self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)): 生成一个下三角矩阵, 用于因果掩码. 注册的是一个常量张量, 不会被当作可训练参数
  • self.n_head = config.n_head: 多头注意力的头数, 在GPT2中, 这个值是12
  • self.split_size = n_state: 就是nx, 可读性, 兼容性
  • self.scale: 表示后续在_attn函数中是否使用1/sqrt(d_k)做缩放
  • self.c_attn = Conv1D(n_state * 3, nx): 一个Conv1D层, 将输入映射到Q, K, V三个矩阵, 总特征维度是n_state*3
  • self.c_proj = Conv1D(n_state, nx): 一个Conv1D层, 将注意力计算结果(value)矩阵做一个线性变换
  • self.attn_dropout = nn.Dropout(config.attn_pdrop): 用于注意力权重的Dropout
  • self.resid_dropout = nn.Dropout(config.resid_pdrop): 用于残差连接的Dropout
  • self.pruned_heads = set(): 记录剪枝被剪掉的注意力头

前向传播

def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
    x = self.c_attn(x)
    query, key, value = x.split(self.split_size, dim=2)
    query = self.split_heads(query)
    key = self.split_heads(key, k=True)
    value = self.split_heads(value)
    if layer_past is not None:
        past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
        key = torch.cat((past_key, key), dim=-1)
        value = torch.cat((past_value, value), dim=-2)
    present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

    attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
    a = attn_outputs[0]

    a = self.merge_heads(a)
    a = self.c_proj(a)
    a = self.resid_dropout(a)

    outputs = [a, present] + attn_outputs[1:]
    return outputs  # a, present, (attentions)
  • layer_past: 过去时间步的K, V矩阵的缓存, 可以加速推理
  • head_mask: 用来对注意力头进行选择性地屏蔽
  • x = self.c_attn(x): 将输入映射到3个矩阵, 形状从(B, T, D)变成(B, T, D*3)
  • query, key, value = x.split(self.split_size, dim=2): 将上一步的拼接结果拆分成Q, K, V三个部分, 每个部分的形状都是(B, T, D), 会在维度2就是最后一个维度上均分三份.
  • query = self.split_heads(query): 把Q拆分成多个注意力头, 形状从(B, T, D)变为(B, H, T, D//H)
  • key = self.split_heads(key, k=True): 将K拆分为多个注意力头, 形状从(B, T, D)变为(B, H, D//H, T), 注意, 这里k=True, 这是因为要对矩阵进行转置, 使其能执行torch.matmul(q, k)
  • value = self.split_heads(value): 把V拆分成多个注意力头, 形状从(B, T, D)变为(B, H, T, D//H)
  • if layer_past is not None:: 确实是否开启KV Cache

    KV Cache

    非常好的解释: https://www.bilibili.com/video/BV17CPkeEEzk/?spm_id_from=333.337.search-card.all.click&vd_source=f86bed5e9ae170543d583b3f354fcaa9

    我们在进行自回归输出的时候, 这个T的值是在不断的变大的, 一开始的时候是提示词的token长度, 随着自回归, 这个值会变得非常大, 例如, 目前输入+输出的长度已经达到了100万token. 那么我们的Q, K, V矩阵会变得非常长, 宽还是嵌入维度(如768). torch.matmul(q, k)产生的方阵会非常非常大, 100万*100万. 这就是为啥你在比较弱一点的硬件上跑的时候, 随着输出长度越来越大或者对话轮数越来越大, 蹦字的速度越来越慢的原因, GPT-3.5的n_ctx是4096, 所以随着你和它聊天轮数的增加, 超出了窗口, 它会对前面的内容进行截取, 导致前面的上下文丢失. 举个例子, 你在叫GPT翻译, 你把"请翻译文本"放在prompt的最前面, 然后粘贴了一个1000字的论文, 然后你聊天聊了几轮之后, 你会发现GPT-3.5似乎在和你对话了, 而不是翻译文本, 这是因为超出了它的窗口, 它看不到一开始的指令tokens了. 这可以通过KV Cache解决, 简单的来说, 我们可以搞一个超级超级大的窗口, 然后用KV Cache实现高效推理.

    • 如果没有KV Cache, 那么这个方阵的下三角区域全部都要重新计算, 非常消耗计算资源
    • 如果有KV Cache, 那么这个方阵的下三角区域只有最后一行要重新计算, 即只有当前的新token的Q, K, V是变的, 而之前所有tokens的K, V都是缓存的. 为什么Q不缓存呢? 我们想要得到这一行, 需要知道前面所有tokens的K, torch.matmul(Q, K)得到权重向量, 这个权重向量代表了前面所有tokens的权重, 和它们的V相乘把信息汇总到当前的这个token的V里面.

    KV Cache的Trade OFF是内存的消耗增加了, 但是QK矩阵乘法的效率增加了.

  • key = torch.cat((past_key, key), dim=-1): 将之前的key和新key在最后一维(dim = -1)进行拼接, 注意了之前Key的维度是(B, H, D//H, T), 拼接之后的维度是(B, H, D//H, T+1)

  • value = torch.cat((past_value, value), dim=-2): 将之前的value和新value在倒数第二维(dim = -2)进行拼接, 注意了之前Value的维度是(B, H, T, D//H), 拼接之后的维度是(B, H, T+1, D//H)
  • present = torch.stack((key.transpose(-2, -1), value)): 每个transformer block都会执行一个forward函数, 所以在一次前向传播中past_keypast_value这两个量都是不会变的, 但是新token的key, value, query在不停改变, 所以present表示的是当前block的KV Cache+新token的KV
  • attn_outputs = self._attn(query, key, value, attention_mask, head_mask): 执行注意力机制的核心计算过程, 返回的是经过更新后的V, 形状为(B, H, T+1, D//H)
  • a = attn_outputs[0]: 取第一个元素, attn_outputs的输出第一个元素是更新后的V, 第二个元素是注意力权重, 用于可视化
  • a = self.merge_heads(a): 将注意力头合并, 形状从(B, H, T+1, D//H)变为(B, T+1, D), 这里的merge_heads函数是split_heads函数的逆操作
  • a = self.c_proj(a): 将注意力计算结果做线性变换, 形状保持不变, 还是(B, T+1, D), 这里的c_proj是一个Conv1D层
  • a = self.resid_dropout(a): 对注意力计算结果施加dropout, 减少模型过拟合, 让注意力分布更具有随机性
  • outputs = [a, present] + attn_outputs[1:]: 返回的结果是一个列表, 包括更新后的Value矩阵, present矩阵和注意力权重(如有)

注意力头生成

def split_heads(self, x, k=False):
    new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
    x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
    if k:
        return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
    else:
        return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
  • x: 一个批次的嵌入, 形状为(B, T, D), 如(2, 200, 768)
  • k: 决定如何重新排列维度, 对K矩阵要执行转置(见forward函数)
  • new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head): 计算新的形状, 原来的形状是(B, T, D), 新的形状是(B, H, T, D//H), 例如, 假设self.n_head=12, (2, 200, 768)->(2, 12, 200, 64)
  • x = x.view(*new_x_shape): 执行形状变换

注意力头合并

def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
  • x: 更新之后的V, 形状为(B, H, T+1, D//H)
  • x = x.permute(0, 2, 1, 3).contiguous(): 将维度重新排列, 变为(B, T+1, H, D//H), 注意这里的contiguous()是为了保证内存连续性, 这样可以避免后续的view操作出错
  • new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),): 计算新的形状, 原来的形状是(B, T+1, H, D//H), 新的形状是(B, T+1, D), 例如, 假设self.n_head=12, (2, 201, 12, 64)->(2, 201, 768)
  • return x.view(*new_x_shape): 执行形状变换, 这里的view操作会把最后两个维度合并成一个维度, 这样就得到了更新后的V矩阵, 形状为(B, T+1, D)

注意力权重矩阵计算

def _attn(self, q, k, v, attention_mask=None, head_mask=None):
    w = torch.matmul(q, k)
    if self.scale:
        w = w / math.sqrt(v.size(-1))
    nd, ns = w.size(-2), w.size(-1)
    b = self.bias[:, :, ns-nd:ns, :ns]
    w = w * b - 1e4 * (1 - b)

    if attention_mask is not None:
        w = w + attention_mask

    w = nn.Softmax(dim=-1)(w)
    w = self.attn_dropout(w)

    if head_mask is not None:
        w = w * head_mask

    outputs = [torch.matmul(w, v)]
    if self.output_attentions:
        outputs.append(w)
    return outputs
  • w = torch.matmul(q, k): 计算Q*K^T, 得到原始的注意力分数矩阵w
  • if self.scale: w = w / math.sqrt(v.size(-1)): 使用1/sqrt(d_k)进行缩放, 避免向量维度过大造成的内积值过大
  • nd, ns = w.size(-2), w.size(-1): 取得注意力分数矩阵的形状
  • b = self.bias[:, :, ns-nd:ns, :ns]: 掩码矩阵, 常常是用一个上三角或者下三角矩阵表示
  • w = w * b - 1e4 * (1 - b): 把被遮住的注意力分数直接变成一个很大的负值, 这样在softmax的时候, 该位置的分数会被极大值抑制
  • if attention_mask is not None: w = w + attention_mask: 如果外部额外传入了attention_mask, 则把这个mask加到注意力分数里
  • w = nn.Softmax(dim=-1)(w): 在最后的一个维度上做softmax, 将注意力分数转为注意力权重
  • self.attn_dropout(w): 对注意力权重施加dropout, 减少模型过拟合, 让注意力分布更具有随机性
  • if head_mask is not None: w = w * head_mask: 如果有head_mask, 说明在训练或者推理的时候要屏蔽掉某些注意力头, 会在对应注意力权重上乘以0(或其他权重)
  • outputs = [torch.matmul(w, v)]: 最后用wv做矩阵乘法, 得到加权后的value输出
  • if self.output_attentions: outputs.append(w): 如果self.output_attentions为真, 那么会在outputs中把注意力权重w一并返回, 方便后续做可视化或其他分析

注意力头剪枝

下面的这个函数用于从模型中移除指定的注意力头, 这是一种常见的模型压缩技术, 可以减小模型的大小, 加快推理速度, 同时尽量保持模型的性能.

def prune_heads(self, heads):
    if len(heads) == 0:
        return
    mask = torch.ones(self.n_head, self.split_size // self.n_head)
    heads = set(heads) - self.pruned_heads  # Convert to set and emove already pruned heads
    for head in heads:
        # Compute how many pruned heads are before the head and move the index accordingly
        head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
        mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)
    index = torch.arange(len(mask))[mask].long()
    index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])

    # Prune conv1d layers
    self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
    self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

    # Update hyper params
    self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
    self.n_head = self.n_head - len(heads)
    self.pruned_heads = self.pruned_heads.union(heads)
  • heads: 需要剪枝的注意力头的列表, 例如[0, 1, 2], 表示要剪掉第0, 1, 2个注意力头
  • if len(heads) == 0: return: 如果没有要剪枝的头, 直接返回
  • mask = torch.ones(self.n_head, self.split_size // self.n_head): 创建一个掩码矩阵, 形状为(H, D//H), 用于标记哪些注意力头是保留的, 哪些是剪掉的, 初始为全1, 表示所有头都保留.
  • heads = set(heads) - self.pruned_heads: 将要剪掉的头转换为集合, 并去除已经剪掉的头, 以避免重复剪枝
  • head = head - sum(1 if h < head else 0 for h in self.pruned_heads): 计算在当前头索引之前有多少头已经被剪枝, 并调整索引值(因为之前的剪枝操作会影响后面的索引)
  • mask[head]=0: 将当前剪掉的头在掩码矩阵中标记为0, 表示该头被剪掉

    调整索引

    假设我们的Transformer模型最初有8个注意力头, 编号为0-7.

    1. 初始状态

      • 所有头都存在: [0, 1, 2, 3, 4, 5, 6, 7]
      • 已剪枝头集合: self.pruned_heads = {}
    2. 第一次剪枝操作

      假设我们要剪掉头 [2, 5]:

      heads = [2, 5]
      

      处理头2:

      • 计算2之前有多少个已被剪枝的头: sum(1 if h < 2 else 0 for h in self.pruned_heads) = 0
      • 调整后的索引: head = 2 - 0 = 2
      • 设置mask[2] = 0(标记为需要剪枝)

      处理头5:

      • 计算5之前有多少个已被剪枝的头: sum(1 if h < 5 else 0 for h in self.pruned_heads) = 0
      • 调整后的索引: head = 5 - 0 = 5
      • 设置mask[5] = 0

      剪枝后:

      • 剩余头: [0, 1, 3, 4, 6, 7]
      • 已剪枝头集合更新为: self.pruned_heads = {2, 5}
    3. 第二次剪枝操作

      现在, 假设我们要剪掉原始编号为 [3, 6] 的头:

      heads = [3, 6]
      

      处理头3:

      • 关键点: 计算3之前有多少个已被剪枝的头: sum(1 if h < 3 else 0 for h in self.pruned_heads) = 1(只有头2小于3)
      • 调整后的索引: head = 3 - 1 = 2
      • 设置mask[2] = 0(注意: 这里的索引2对应的是当前数组中的第三个元素, 也就是原始头3)

      处理头6:

      • 关键点: 计算6之前有多少个已被剪枝的头: sum(1 if h < 6 else 0 for h in self.pruned_heads) = 2(头2和头5都小于6)
      • 调整后的索引: head = 6 - 2 = 4
      • 设置mask[4] = 0(注意: 这里的索引4对应的是当前数组中的第五个元素, 也就是原始头6)

      剪枝后:

      • 剩余头: [0, 1, 4, 7]
      • 已剪枝头集合更新为: self.pruned_heads = {2, 3, 5, 6}
  • 后续代码

    剪枝

    假设我们有一个Transformer模型, 其配置如下:

    • 4个注意力头(n_head = 4)
    • 每个头的维度为16(split_size // n_head = 16)
    • 总特征维度为64(split_size = 64)
    • 我们要剪枝的头是 [1, 3]
    1. 创建并应用掩码

      经过前面的代码处理后, 我们得到了一个掩码矩阵:

      [[1,1,...,1],  # 头0: 保留(16个1)
      [0,0,...,0],  # 头1: 剪枝(16个0)
      [1,1,...,1],  # 头2: 保留(16个1)
      [0,0,...,0]]  # 头3: 剪枝(16个0)
      

    2. 处理掩码和创建索引

      mask = mask.view(-1).contiguous().eq(1)
      

      这一行将二维掩码展平成一维并转换为布尔值:

      • 展平后: [1,1,...,1, 0,0,...,0, 1,1,...,1, 0,0,...,0]
      • 转换为布尔值后: [True,True,...,True, False,False,...,False, True,True,...,True, False,False,...,False]
      index = torch.arange(len(mask))[mask].long()
      

      这一行创建保留位置的索引:

      • 首先生成 [0,1,2,...,63]
      • 然后只保留掩码为True的位置
      • 结果是: [0,1,...,15, 32,33,...,47]
    3. 创建注意力矩阵的索引

      index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
      

      在Transformer中, QKV(查询, 键, 值)通常合并在一个矩阵中:

      • 前64个位置对应Q
      • 中间64个位置对应K
      • 后64个位置对应V

      因此这行代码构建了完整的索引:

      • index: [0,1,...,15, 32,33,...,47] (Q部分)
      • index + 64: [64,65,...,79, 96,97,...,111] (K部分)
      • index + 128: [128,129,...,143, 160,161,...,175] (V部分)
      • 合并后: [0,1,...,15, 32,33,...,47, 64,65,...,79, 96,97,...,111, 128,129,...,143, 160,161,...,175]
    4. 实际执行剪枝操作

      self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)  # 列方向索引
      self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)  # 行方向索引
      

      若没有剪枝:

      • c_attn输入形状: [B, T, 64]
      • c_proj输入形状: [B, T+1, 64]
      • 经过c_attn的输出形状: [B, T, 3×64] (计算QKV的层)
      • 经过c_proj的输出形状: [B, T+1, 64] (注意力输出投影层)

      若有剪枝(原本是4个头, 现在只要2个头, 但是每个头的维度还是16):

      • c_attn输入形状: [B, T, 64]
      • c_proj输入形状: [B, T+1, 32]
      • 经过c_attn的输出形状: [B, T, 3×32] (只保留头0和头2的列)
      • 经过c_proj的输出形状: [B, T+1, 32] (只保留头0和头2的行)
    5. 更新模型参数

      self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
      

      新的特征维度 = (64 ÷ 4) × (4 - 2) = 16 × 2 = 32

      self.n_head = self.n_head - len(heads)
      

      新的头数量 = 4 - 2 = 2

      self.pruned_heads = self.pruned_heads.union(heads)
      

      更新已剪枝头的集合: self.pruned_heads = self.pruned_heads ∪ {1, 3}

评论