上一篇文章搞定了RoPE,接下来该实现llama中使用的另外一个关键技术:Group Query Attention,分组查询注意力机制。 相比RoPE,个人觉得这个好理解一些,不过代码实现上难度大一些。当然了,我还是给它实现出来了。不过有关解码的部分,GQA的kv cache与原来的MHA不太一样,目前我的项目还没做到解码的部分,等我的SimpleLLM项目实现到这一部分应该会在后面加上,或者专门写一篇文章讲讲解码的全过程实现,里面会包含这几种注意力机制(还有MQA)的解码。
手搓LLM架构与训练过程:SimpleLLM
在讲解GQA之前,我们先讲讲最开始的MHA(Multi Head Attention,多头注意力机制),与改进后的MQA(Multi Query Attention,多查询注意力机制)。
MHA与MQA
先讲MHA,多头注意力机制,这是在transformer中最开始的注意力机制,相信大家也比较熟悉。
我们现在有这么一个序列:
$$ Sequence=\{token_i\}_{i=1}^N $$经过embedding之后,得到的词嵌入序列:
$$ X=\{x_i\}_{i=1}^N $$在多头的其中一个头内部的注意力机制中,每个$x_i$都要进行三次投影,分别得到$q_i$, $k_i$, $v_i$,相当于是一套q,k,v。有多少头,就是说的一个token对应的向量会被生成几套q,k,v。
图中表示的是一个token对应的向量,在经过q,k,v的生成后,一共生成了8套q,k,v,那么q的数量应该与k和v的数量相等。
当进行注意力计算时,每个头各算各的,因为都有独立的q,k,v。任意一个头,例如第一个头$head_0$,会拿自己的q,和同样是$head_0$生成出来的所有token对应的k进行相似度计算,再算出attn weights,对同一个头的v进行加权求和得到结果。每个头都做这个操作,最后将每个头的结果concatenate起来,就是这一层多头注意力的输出结果。
Attention is All You Need中提到,多头的灵感来源于CNN的不同通道提取不同的特征,于是也想在注意力机制中引入类似的东西:不同的头进行的操作相同且互不干扰,提取同一个数据中的不同特征。
接下来我们来讲讲MQA,多查询注意力机制。
MHA固然很好,但是它的复杂度还是太大了,每个token都要与其它token作用,便是$O(n^2)$.每次训练时注意力机制的开销都很大。并且,运营成本中比较关键的推理时性能,MHA也很不好。推理时有个东西叫做kv cache,就是保存已经计算过的token的k与v。MHA中每个token都要有num heads个k与v,这样随着序列的增长,kv cache的内存占用会迅速变大,面对长上下文的追求,这种缺点必须被克服或缓解。除了减少num heads或者d model,这里有个比较激进的方法,就是MQA。
MQA中,一个token只会生成一套k,v,但是依旧会生成num heads个q,但是这num heads个头的计算中,是共享k与v的,也就是强制每个头中的k与v都是相同的值。例如说,一个token对应的向量$x_i$,生成一个$k$与一个$v$, 以及num heads个$q$.在计算时,第一个head,会拿$q_0$与所有的k进行相似度计算;第二个头中,会拿$q_1$与所有的k进行相似度计算;……不过,这里的k都是同一套,只有不同token之间才有区别,同一个token的k在不同的头中都是相同的值。v也同理。
MQA使得无论你num heads是多少,都只影响q的数量。不影响k与v的数量。那这样进行解码(推理)时,每个token只需要保存一套kv cache就可以了,在大部分模型的num heads都是几十的时候,这就降低了几十倍kv cache的内存占用。同样的,在训练时,由于参数大量共享,实际只有一套,参数本身内存占用与运行时内存占用也小很多(注意,实际上计算量是没有什么变化)。而这种张量变小的优势不仅仅是内存占用小,还有带宽占用小,能更快地把张量从内存加载到缓存,进行计算上的加速。
那么,代价呢?
从直觉上看,这种方式必然损失性能,不管怎么说,只有一套kv,表示能力肯定会下降不少。这一点我们在后面的图表中可以看出来。
于是,很自然地,我们想要在两种方案中折中,想要一种没那么极端的解决方案,甚至还能通过一些超参数来进行调节,更好地对齐我们实际需求。这就有了GQA,分组查询注意力机制。
GQA,分组查询注意力机制
直接看图。这张图还是很显然的:既然一套kv太拉跨,那么我们就多整几套,但是还是比num heads要小。从直觉上来说,这么做相比MQA提升了性能,相比MHA降低了开销,达成了一种中间状态。
在有多套kv的情况下,就不是所有头都共享kv了,而是要将所有头进行分组(其实原来的说法是对q进行分组,不过也没差啦~)。这里有4套kv,自然要分成4组,每组2个头,组内的头kv的值相同。例如说,第一个头$head_0$与第二个$head_1$同组,那么它们在进行注意力计算时,$q_i^0$与$q_i^1$要进行操作的k与v是相同的(这里i的意思是第i个位置的token对应的q);而第三个头$head_2$与第四个头$head_3$同组,那么它们在进行注意力计算时,$q_i^2$与$q_i^3$要进行操作的k与v是相同的。以此类推。
接下来我们看看GQA的优势:
这张图里面的MHA-Large,就是比MHA-XXL更小的MHA,用于证明在差不多的开销减少的情况下,MQA和GQA的性能强于单纯减少num heads或者d model等粗暴方法。
横轴是每个样本所花时间,也就是时间开销,纵轴是performance。可以看到GQA的性能仅仅比MHA弱了一点,而时间开销上却远远小于MHA,同时也只比MQA大一点点。当然了,也有可能是研究人员用的硬件带宽他们自己知道,故意卡了一下那个传输瓶颈,毕竟总计算量大家都差不多,比的基本就是传输开销。
代码实现
原本想参考这个项目与这篇知乎文章.但是这个里面竟然是我最讨厌在成品里面使用的einops。这种需要字符串解析而且通用性过强的东西我怎么看怎么觉得效率不行。咱也不是说要性能压榨得多厉害,不然我会去搓c/cuda,但是我还是有点难以接受这种。
所以还是拿原生torch搓了,不管效率是不是真的提高了,起码看着舒服一些。
import torch
import torch.nn as nn
class GroupQueryAttention(nn.Module):
def __init__(self,
d_xq: int,
d_xk: int,
d_xv: int,
h: int,
num_heads: int,
group_size: int,
dropout: float=0.1,
rope_config: dict| None=None
) -> None:
'''
Args:
d_xq: dimension of query
d_xk: dimension of key
d_xv: dimension of value
h: dimension of hidden
num_heads: number of heads
dropout: dropout rate
rotary_positional_encoding: whether to use RoPE (Rotary Positional Encoding)
'''
super().__init__()
self.d_xq = d_xq
self.d_xk = d_xk
self.d_xv = d_xv
self.num_heads = num_heads
self.h = h
self.group_size = group_size
self.apply_rope = False
self.kv_num_heads = kv_num_heads = num_heads // group_size
self.W_q = nn.Linear(d_xq, h * num_heads, bias=False)
self.W_k = nn.Linear(d_xk, h * kv_num_heads, bias=False)
self.W_v = nn.Linear(d_xv, h * kv_num_heads, bias=False)
self.W_o = nn.Linear(h * num_heads, d_xv, bias=False)
self.dropout = nn.Dropout(dropout)
if rope_config is not None:
self.apply_rope = True
self.rope = RoPE(h, max_n=rope_config.get('max_n', 1024),\
base=rope_config.get('base', 10000.0))
def forward(self,
xq: torch.Tensor,
xk: torch.Tensor,
xv: torch.Tensor,
mask: torch.Tensor | None=None
) -> torch.Tensor:
'''
Args:
xq: query, [batch_size, n_q, d_xq]
xk: key, [batch_size, n_k, d_xk]
xv: value, [batch_size, n_v, d_xv]
mask: [batch_size, n_q, n_k] or [n_q, n_k] or None, True: mask, False: no mask
Returns:
o: [batch_size, n_q, d_xv]
'''
# get batch_size, n_q, n_k, n_v, num_heads, h, group_size, kv_num_heads.
n_q, n_k, n_v = xq.size(1), xk.size(1), xv.size(1)
num_heads = self.num_heads
h = self.h
group_size = self.group_size
kv_num_heads = self.kv_num_heads
# separate each head.
# (batch_size, n, h * num_heads) -> (batch_size, n, num_heads, h) -> (batch_size, num_heads, n, h) -> (batch_size * num_heads, n_q, h)
q = self.W_q(xq).reshape(-1, n_q, num_heads, h).permute(0, 2, 1, 3).reshape(-1, n_q, h)
# (batch_size, n, h * kv_num_heads) -> (batch_size, n, kv_num_heads, h) -> (batch_size, kv_num_heads, n, h)
k = self.W_k(xk).reshape(-1, n_k, kv_num_heads, h).permute(0, 2, 1, 3)
v = self.W_v(xv).reshape(-1, n_v, kv_num_heads, h).permute(0, 2, 1, 3)
# (batch_size, kv_num_heads, n_k, h) -> (batch_size, 1, kv_num_heads, n_k, h) -> (batch_size, group_size, kv_num_heads, n_k, h)
# -> (batch_size * group_size * kv_num_heads, n_k, h) = (batch_size * num_heads, n_k, h)
# Tips: expand will not allocate new memory, it will just create a new view of the original tensor.
k = k.reshape(-1, 1, kv_num_heads, n_k, h).expand(-1, group_size, -1, -1, -1).reshape(-1, n_k, h)
v = v.reshape(-1, 1, kv_num_heads, n_v, h).expand(-1, group_size, -1, -1, -1).reshape(-1, n_v, h)
# apply RoPE if enabled
if self.apply_rope:
q = self.rope(q)
k = self.rope(k)
# Q*K^T/sqrt(h)
attn_score = torch.bmm(q, k.transpose(-1, -2)) / sqrt(h) # (batch_size * num_heads, n_q, n_k)
# apply mask
# if mask is a 2d tensor, it will add a batch dimension to broadcast it to the batch size.
# if mask is a 3d tensor, it will repeat the mask for each head.
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(0)
elif mask.dim() == 3:
mask = torch.repeat_interleave(mask, num_heads, dim=0)
else:
raise ValueError("mask's dim must be 2 or 3")
attn_score = attn_score.masked_fill(mask, -1e9)
# softmax attn_score to get attn_weights, and apply it to v to get v_pool.
attn_weights = F.softmax(attn_score, dim=-1)
v_merge = torch.bmm(self.dropout(attn_weights), v)
# (batch_size * num_heads * n_q, h) -> (batch_size, num_heads, n_q, h) -> (batch_size, n_q, num_heads, h) -> (batch_size, n_q, h * num_heads)
v_merge = v_merge.reshape(-1, num_heads, n_q, h).permute(0, 2, 1, 3).reshape(-1, n_q, h * num_heads)
o = self.W_o(v_merge) # (batch_size, n_q, d_xv)
return o
形状标注得应该还是比较详细。不过这里我要单独讲一讲这个torch.Tensor.expand。这个玩意其实是不会复制元素的,只修改了元信息,创建了一个目标形状的view。所以并不会增加内存开销。这一点官方文档有写。所以不要说你按照元素数量算出来内存开销怎么和MHA一样。
实际上广播机制也是不会分配新内存的。广播函数broadcast_to也不会。
这里面的RoPE其实是我之前的一篇文章,专门讲旋转位置编码的,感兴趣可以去看看。
文章和代码都是自己一个人搓的,难免有错误。欢迎纠错,也欢迎和平讨论~
欢迎友好讨论~