Fork me on GitHub

谁是大模型的显存救星?

以下文章来源于 https://zhuanlan.zhihu.com/p/667993454



一、前言

我们在训练或者推理大模型的时候,会发现随着时间的增长显存也会慢慢的上升,其实是大模型训练和推理时采用的一种加速方式导致的。

因为Decoder 每次前向,当前时间步计算 Attention 要用到的key和value,如之前时间步的KV(Key 和 Value)值都计算过的,只是之前每次前向完后给计算结果都丢掉,只保留最后输出。这样就会让模型大量重复计算,于是KV Cache的加速方法就出现了,用空间换时间,但是大模型参数巨量,计算时会产生大量的KV矩阵,使得显存占用极高,对显存的需求也是非常惊人。

而且 K 和 V 能直接存在缓存中,模型规模小还好,一旦模型规模很大长度很长时,KV 根本就存不进缓存。现在h100和a100的SRAM 缓存都远小于需求,除了能存一部份kv显存之外,大量的kv都是需要重新计算,其实并没有加速多少。

有一些方法针对这方面进行优化,对Attention加速和KV Cache显存占用优化,在效果和速度上求个均衡,lllama系列就提出了MQA(Multi Query Attention)和GQA(Grouped Query Attention),还有Flash Attention1和2。


二、MQA(Multi Query Attention)

MQA是对于传统的MHA(Multi Head Attention)做的改进,query仍然保持多个,而key和value是共享的,这样避免重复每一个query都计算kv,这种方式肯定是非常粗暴的省显存,思想有点类似于Albert共享层权重一样,这样确实能够达到目的,但是因为所有query都用一个kv,在效果上肯定是会打折扣的,但是同样也带来了不少收益,模型吞吐量能提高30%-40%。所以后面又提出来GQA(Grouped Query Attention)方法优化,争取在效果和速度上均衡。

import os import math import torch.nn as nn import torch class MultiQuerySelfAttention(nn.Module): def __init__(self, num_attention_heads, hidden_size): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_size = int(hidden_size / num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(hidden_size, self.all_head_size) self.key = nn.Linear(hidden_size, self.attention_head_size) self.value = nn.Linear(hidden_size, self.attention_head_size) self.dropout = nn.Dropout(0.1) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self,hidden_states): # hidden_states (B, L, D) mixed_query_layer = self.query(hidden_states) # query_layer (B, h, L, d) query_layer = self.transpose_for_scores(mixed_query_layer) # 每个key、value head参数都是一样的,只计算一次 key = self.key(hidden_states) #key_layer (B, 1, L, d) key_layer = key.unsqueeze(1) value = self.value(hidden_states) # value_layer (B, 1, L, d) value_layer = value.unsqueeze(1) # key_layer (B, 1, d, L) key_layer = key_layer.transpose(-1, -2) #广播算法 (B, h, L, d) * (B, 1, d, L) => (B, h, L, d) * (B, h, d, L) = (B, h, L, L) attention_scores = torch.matmul(query_layer, key_layer) attention_scores = attention_scores / math.sqrt(self.attention_head_size) attention_probs = nn.functional.softmax(attention_scores, dim=-1) attention_probs = self.dropout(attention_probs) #广播算法 (B, h, L, L) * (B, 1, L, d) =>(B, h, L, L) * (B, h, L, d)= (B, h, L, d) context_layer = torch.matmul(attention_probs, value_layer) #(B, h, L, d) => (B, L, h, d) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # (B,L, h*d) => (B,L,D) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) # (B,L, h*d) => (B,L,D) context_layer = context_layer.view(new_context_layer_shape) return context_layer

三、GQA(Grouped Query Attention)

上面说了MQA(Multi Query Attention)的改进方法,是多个query共用一个key和value矩阵,这样会让多个query差异变小模型效果变差,做法过于粗暴。GQA则是在上面做了些优化,不让所有的query共用一个key和value矩阵,分组来共享一个key和value矩阵,也是能够减少一些key和value矩阵的产生,这是性能和效果的取舍较为平衡的点。

代码来源:GitHub - fkodom/grouped-query-attention-pytorch: (Unofficial) PyTorch implementation of grouped-query attention (GQA) from "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (https://arxiv.org/pdf/2305.13245.pdf)

import os import math import torch.nn as nn import torch class MultiheadGQA(nn.Module): """Multi-head grouped query attention (GQA) layer. Reference: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" https://arxiv.org/pdf/2305.13245v1.pdf GQA is a variant of multihead attention (MHA) that uses fewer write heads (key / value) than query heads. GQA can be viewed as a generalization of multi-query attention (MQA), which uses a single write head. GQA and MQA give significant speedups over standard MHA in decoder layers, with minimal loss in accuracy. In the paper, GQA is shown to be more accurate than MQA, while still having a significant speedup over MHA. NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model from MHA to GQA. As a result, they do not mention parameter initialization or layer normalization strategies. I follow the best practices laid out in the MAGNETO paper, which improves Transformer performance through better parameter initialization and layer norm placement. See: https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 """ def __init__( self, embed_dim: int, query_heads: int, kv_heads: int, dropout: float = 0.0, bias: bool = True, layer_norm: bool = True, layer_norm_eps: float = 1e-5, gamma_init: float = 1.0, device: Optional[Union[torch.device, str]] = None, dtype: Optional[torch.dtype] = None, ): super().__init__() self.query_heads = query_heads self.kv_heads = kv_heads self.dropout = dropout self.layer_norm = layer_norm self.gamma_init = gamma_init if self.query_heads % self.kv_heads != 0: raise ValueError( f"query_heads ({query_heads}) must be divisible by " f"kv_heads ({kv_heads})" ) elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0): raise ValueError( f"embed_dim ({embed_dim}) must be divisible by " f"query_heads ({query_heads}) and kv_heads ({kv_heads})" ) head_dim = embed_dim // query_heads if not head_dim % 8 == 0: raise ValueError( f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8" ) if not head_dim <= 128: raise ValueError( f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128" ) # Query projection layer is the same as in vanilla MHA. self.q_proj = nn.Linear( embed_dim, embed_dim, bias=bias, device=device, dtype=dtype ) # Key/value projection layers have a smaller output dimension, so that # the we have fewer key/value attention heads after reshaping. kv_embed_dim = embed_dim // query_heads * kv_heads self.k_proj = nn.Linear( embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype ) self.v_proj = nn.Linear( embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype ) self.norm: Optional[nn.LayerNorm] = None if layer_norm: self.norm = nn.LayerNorm( kv_embed_dim, eps=layer_norm_eps, device=device, dtype=dtype ) # Grouped attention output will have the same embedding dimension as the # key/value Tensors. So the output projection layer needs to accept the # same dimension (kv_embed_dim). self.out_proj = nn.Linear( kv_embed_dim, embed_dim, bias=bias, device=device, dtype=dtype ) self._reset_parameters() def _reset_parameters(self): nn.init.xavier_normal_(self.q_proj.weight) if self.q_proj.bias is not None: nn.init.constant_(self.q_proj.bias, 0) nn.init.xavier_normal_(self.k_proj.weight) if self.k_proj.bias is not None: nn.init.constant_(self.k_proj.bias, 0) # NOTE: We follow the initialization strategy from MAGNETO. See: # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 # Gain (self.gamma_init) should be provided as a keyword argument when # initializing the larger Transformer model, since it requires knowledge # of the number of encoder/decoder layers in the model. nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init) if self.v_proj.bias is not None: nn.init.constant_(self.v_proj.bias, 0) nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init) if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0) def forward( self, query: Tensor, key: Tensor, value: Tensor, need_weights: bool = False, # TODO # attn_mask: Optional[Tensor] = None, is_causal: bool = False, average_attn_weights: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: # Notation: # b - batch size # n - sequence length # h - number of heads # d - embedding dimension # # Input shape: (b, n, d) q: Tensor = self.q_proj(query) k: Tensor = self.k_proj(key) v: Tensor = self.v_proj(value) # Unfold 'd' dimension into 'h' separate attention heads. q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads) k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads) v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads) # Apply attention, then fold 'h' attention heads back into 'd'. x, attn = scaled_dot_product_gqa( query=q, key=k, value=v, # TODO # mask=attn_mask, is_causal=is_causal, need_weights=need_weights, average_attn_weights=average_attn_weights, force_grouped=False, ) x = rearrange(x, "b n h d -> b n (h d)") # NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO # architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra # layer norm before the linear output projection. The cross-attention layer in # the MAGNETO decoder does not include this layer norm, so users have the # option to disable it (layer_norm=False). if self.layer_norm: assert self.norm is not None x = self.norm(x) # Linear projection on attention outputs. x = self.out_proj(x) return x, attn

lllama2就是采用GQA(Grouped Query Attention)方案来做加速,提高模型吞吐量。


四、Flash Attention 1

看图上基于NVIDIA GPU显存架构的a100的40g显卡,可以看出来SRAM的速度比HBM(显存)快非常多,但是SRAM容量小很多,所以如果把涉及到显存IO的操作放在SRAM,能够大幅度提高运算速度。

普调的Attention运算query、key、value矩阵都在HBM(显存)中计算,但HBM(显存)的运算速度比不上SRAM。

标准注意力计算:

标准注意力计算会将S和P的矩阵值存在显存中,以便于后续反向传播时使用。标准注意力在计算attention score时会频繁IO key和value矩阵,而这些矩阵则是存在HBM显存中,运算速度不快,尤其是在大模型参数多的时候,有很多attention score要计算就会存在大量IO操作,如果能将这些IO操作转移到SRAM则会提升很多速度。

Flash Attention 1

优化一 (不全局计算softmax,采用分块计算机制):

标准注意力计算的方法是先将key和query点乘之后再通过softmax归一化然后再和value相乘,Flash Attention中,将query、key、value进行分块计算softmax,这样可以将计算放到SRAM中提升计算速度,每一个块的softmax输出的缩放数据可以合并成一个完整的softmax数据。


优化二 (反向传播过程中通过归一化因子重计算中间注意力,减少整个矩阵存储):

在传统的注意力机制中,为了计算中间的注意力矩阵,需要将计算过程中的S和P存储到HBM(显存)中。然而,这些中间矩阵的大小与输入序列的长度成二次型关系,因此会增加内存的消耗。为了解决这个问题,Flash Attention提出了一种新的算法,它避免了存储中间注意力矩阵的步骤,而是通过存储归一化因子来减少HBM内存的使用。

在Flash Attention的前向计算算法中,可以看出它并没有将S和P直接写入HBM中,而是将它们分块写入HBM,然后存储前向传播的softmax归一化因子。这样,在后向传播阶段可以快速重新计算片上的注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。尽管由于重新计算而增加了FLOPS(浮点运算数),但Flash Attention在运行速度上更快且使用更少的内存(与序列长度呈线性关系),主要是因为大大减少了HBM的访问量。


Flash Attention 1让运算速度提升,显存占用减少,是一个非常有效的优化方法。

具体实现:
https://github.com/Dao-AILab/flash-attention

五、Flash Attention 2

Flash Attention 2创新优化点:

1、减少了non-matmul运算,虽然非矩阵运算占比少,但是耗时高,不能利用gpu矩阵运算优势,非矩阵运算会比矩阵运算慢16倍,会拖慢整个运算速度,将一些算子计算写成矩阵运算来加速

在计算输入块attention分数时flash attention在计算的时候会除以当前的块的缩放因子\operatorname{diag}\left(\ell^{(2)}\right)^{-1},为了减少non-matmul运算,不用在每次块计算时除以当前缩放因子,当所有计算完成之后再乘上缩放因子缩放即可:

2、在反向传播时不需要保存 m^{(j)} 和指数之和 \zeta^{j} ,转而保存logsumexp值 L_{(j)} = m^{(j)} + log(\zeta^{(j)})

计算转化:

因果掩码

因果遮蔽(Causal masking)。注意力机制的一个常见用例是自回归语言模型,在这种情况下,我们需要对注意力矩阵S应用一个因果遮蔽(causal mask)(即将任何满足 > 的S 的条目设置为-∞)。

  1. 由于FlashAttention和FlashAttention-2已经按块操作,对于所有列索引都大于行索引的块(对于大的序列长度约占一半的块),我们可以跳过该块的计算。这相对于没有因果遮蔽的注意力机制,带来了大约1.7-1.8倍的速度提升。
  2. 对于行索引始终严格小于列索引的块,我们不需要应用因果遮蔽。这意味着对于每一行,我们只需要将因果遮蔽应用到1个块(假设块是矩阵)中。

并行化

FlashAttention 1在批大小(batch size)和头数(number of heads)上进行并行化。我们使用一个线程块来处理一个注意力头(attention head),总共有(batch size x 头数)这么多的线程块。每个线程块被调度在一个流式多处理器(streaming multiprocessor,SM)上运行。例如,在A100 GPU上有108个这样的SM。当这个数量很大(例如≥ 80)时,这种调度是高效的,因为我们可以有效地利用GPU上的几乎所有计算资源。

在长序列的情况下(通常意味着小的批大小或小的头数),为了更好地利用GPU上的多处理器,在序列长度维度上进行了额外的并行化。这在这种情况下显著提高了速度。

前向传播:

可以看到外部循环(序列长度)可以进行并行处理,而且它们可以在不需要彼此通信的不同线程块上进行调度。还像FlashAttention一样在批维度和头数维度上进行了并行化。当批大小和头数较小时,序列长度的增加并行性有助于提高占用率(正在使用的GPU资源的比例),从而在这种情况下提高速度。

反向传播:

需要注意的是,不同列块之间唯一共享的计算是在更新dQ的过程中,我们需要从HBM加载dQ 到SRAM,然后在芯片上更新 dQ_{i} \leftarrow dQ_{i} + dS_{i}^{(j)}K_{j} ,并写回到HBM。因此,在序列长度维度上进行了并行化,并为反向传播的每个列块安排了1个线程块。我们使用原子加法(atomic adds)在不同线程块之间进行通信,以更新dQ。

对于每个块,FlashAttention 将 K 和 V 分成 4 个warp,同时保持 query 可访问通过所有warp。 每个warp相乘得到 QK⊤ 的切片,然后需要与 V 的切片相乘,将结果相加。 这称为"split-K"方案。 因为所有 warp 都需要将其中间结果写入共享内存进行同步,然后将结果相加中间结果。 这些共享内存读/写会减慢 FlashAttention 中的前向传播速度。

在 FlashAttention-2 中,我们将 Q 分成 4 个warp,同时保持所有warp均可访问 key 和 value。在每个warp执行矩阵乘法以获得 QK⊤ 的切片后,它们只需要与它们共享的相乘V 的切片以获得其相应的输出切片。 warp 之间不需要通信。共享内存读/写的减少可以提高速度。

增加块大小通常会减少共享内存加载/存储,但会增加所需寄存器的数量和共享内存的总量。 超过一定的块大小,内存溢出会导致通信的速度减慢,或者所需的共享内存量大于 GPU 剩余的内存量内核根本无法运行。 通常我们选择大小为 {64, 128} × {64, 128} 的块,取决于head尺寸 和设备共享内存大小。手动调整每个头部尺寸,因为块尺寸基本上只有 4 种选择,但是这可以受益于自动调整以避免这种手工劳动。

从上面看出来FlashAttention-2在长度更长和参数更多速度提升越明显。


本文地址:https://www.6aiq.com/article/1700666721319
本文版权归作者和AIQ共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出