streaming-llm(无需微调无限扩展大模型输入)论文笔记
一、前言
在多轮对话等流式应用中部署大型语言模型(LLMs)是非常重要的,因为大模型的优点是能够记住足够长的上下文对话内容带来长时间的互动,但这也带来了两个主要挑战。首先,在解码阶段,缓存先前token的key和value状态(KV)会消耗大量的内存。其次,热门开源的LLM无法泛化到比训练序列长度更长的文本。仅缓存最近KV的窗口注意力是一种常见的做法,但我们发现当文本长度超过缓存大小时,这种方法就会失效。论文中观察到一个有趣的现象,即Attention Sink,保持初始token的key和value可以大大恢复窗口注意力的性能。
在这篇论文中,证明了attention sink的出现是由于即使初始token在语义上并不重要,它们仍然影响着注意力的分数。基于上述分析,论文中提出StreamingLLM,一个高效的框架,使用有限长度注意力窗口训练的LLMs能够无差别地泛化到无限序列长度,而无需进行任何微调。论文发现StreamingLLM可以使Llama-2、MPT、Falcon和Pythia等模型在多达400万甚至更多的token上进行稳定且高效的语言建模。此外,我们还发现,在预训练过程中添加一个占位符token作为专门的Attention Sink,可以进一步提高流式部署的性能。在流式设置中,StreamingLLM相对于滑动窗口重新计算基准实现了高达22.2倍的速度提升。
二、LLM扩展序列长度主要挑战
当将LLM应用于无限token输入时,会出现两个主要挑战:
- 在解码阶段,基于 Transformer 的 LLM 缓存 Key 和 Value 状态 (KV),这可能会导致内存过多
使用和增加解码延迟。
- 现有模型的长度外推能力有限,即当序列长度超过时,其性能会下降注意预训练期间设置的窗口大小。
四种注意力机制长文本表现
从上面图中可以看出四种注意力机制在长文本中的表现:
- 第一种注意力机制叫做密集注意力(Dense Attention) ,密集注意力随着token的数量增加,带来基本上是指数kv的增长,当token到达一定量级kv的缓存非常大,并且模型的**PPL(困惑度)**变成一个极大值,性能衰减非常严重。基本上不能实现长序列输入,对于多伦对话场景长上文输入会带来当前输入内容被遗忘,模型回答会不可控。
- 第二种注意力机制叫做窗口注意力(Window Attention) ,窗口注意力机制主要是使用窗口缓存token注意力,缓存窗口内token的key和value矩阵,当不在窗口内token的key和value将会被丢弃,尤其是起始的token,虽然不需要缓存这么多key和value能够带来推理性能的提升,但是同时丢弃了很多信息会导致效果下降严重,PPL急剧升高。
- 第三种注意力机制叫做滑动窗口重计算(Sliding Window w/ Re-computation) ,滑动式重新计算的窗口根每个新token的 L (窗口大小)个最近token重建 KV 状态。虽然它在长文本上表现良好,但它的 O( )复杂性,源于二次注意力在上下文重新计算中,使得速度相当慢,工业应用困难。
- 第三种注意力机制叫做初始token融合窗口重计算注意力(StreamingLLM,中文我瞎编的),StreamingLLM 保持注意力集中(几个初始token)与最近的token相结合,用于稳定的注意力计算。 效率很高并在扩展文本上提供稳定的性能。
三、Attention Sink的原因
要想理解窗口注意力为啥会存在比较大的缺陷,我们会发现自回归LLM有一个有趣的现象:注意力分数最大总是集中在初始的几个token上,而不管它们与语言建模任务的相关性如何,我们将这些标记token称为"Attention Sink"。尽管它们缺乏语义意义,但它们收集了大量的注意力分数。
我们把原因归因于Softmax操作,它要求所有上下文标记的注意力分数之和为1。因此,即使当前查询在许多先前标记token中没有强烈的匹配任务信息,模型仍然需要在某个地方分配这些不必要的注意力值,使其总和为1。初始标记token作为Sink的原因是很直观的:由于自回归语言建模的性质,初始标记token对于几乎所有后续token推理时都是可见的,这使得它们更容易训练成为Attention Sink。
四、StreamingLLM框架
基于上述见解,论文提出了StreamingLLM,这是一个简单且高效的框架,使用有限注意力窗口训练的LLM能够处理无限长度的文本,而不需要微调。StreamingLLM利用了注意力陷阱具有高注意力值的事实,保留它们可以保持注意力得分分布接近正常。因此,StreamingLLM简单地保留了Attention Sink token的KV(只需要4个初始标记token就足够了)以及滑动窗口的KV,以锚定注意力计算并稳定模型的性能。
借助StreamingLLM,包括Llama-2-[7, 13, 70]B、MPT-[7, 30]B、Falcon-[7, 40]B和Pythia-[2.9,6.9,12]B在内的模型可以可靠地模拟400万token扩展,甚至可能更多。与唯一的可行基线(带有重新计算的滑动窗口Sliding Window w/ Re-computation)相比,StreamingLLM实现了高达22.2倍的速度提升,实现了LLM的流式输出。
大模型前两层和后几层注意力分布
在Llama-2-7B上对256个句子(每个句子长度为16)的平均注意力logits的可视化。
观察结果包括:
(1)在前两层(层0和层1)的注意力热度图呈现出"局部"模式,最近的token接收到更多的注意力。
(2)在底层之外,模型在所有层和头中都对初始token进行了大量关注。
最后,证实了Attention Sink 假设,并证明了语言模型可以在预训练时仅需要一个Attention Sink Token 来进行流式部署。具体来说建议在所有训练样本的开始处添加一个可学习的额外token,作为指定的Attention Sink 。通过从头开始预训练1.6亿参数的语言模型,我们证明了添加这个单一的Attention SinkToken 可以保持模型在流式输出情况下的性能。这与传统的模型形成对比,后者需要在实现相同性能水平时将多个初始标记重新引入作为Attention Sink。
20k长文本输出四种注意力机制ppl表现
对各种 LLM 中具有 20K 个标记的文本进行语言建模的困惑度。 观察结果揭示一致的趋势:
(1)一旦输入长度超过预训练,密集注意力就会失败。
(2) 一旦输入长度超过缓存大小,窗口注意力就会崩溃。
(3) StreamingLLM表现出稳定的性能,其困惑度几乎与具有重新计算基线的滑动窗口相匹配。
五、无需微调可以长距离扩展
为了在已经训练好的LLM中启用LLM流式处理,论文提出了一种简单的方法,可以在不进行任何模型微调的情况下恢复窗口注意力的困惑度。
在这个方法中,注意力计算中重新引入了最初几个token的KV。StreamingLLM中的KV缓存可以从概念上分为两部分。
StreamingLLM原理计算细节
如上图所示:
(1)Attention Sink(四个初始token)使注意力计算稳定;
(2)滚动KV缓存保留了最近的token,这对于语言建模至关重要。StreamingLLM的设计是通用的,可以无缝地整合到任何使用相对位置编码的自回归语言模型中,如RoPE(Su等人,2021)和ALiBi(Press等人,2022)等相对位置编码模型。
当确定相对距离并为token添加位置信息时,StreamingLLM关注的是缓存内的位置,而不是原始文本中的位置。这种区别对于StreamingLLM的性能至关重要。例如,如果当前缓存中有标记[0, 1, 2, 3, 6, 7, 8],并且正在解码第9个token,那么分配的位置是[0, 1, 2, 3, 4, 5, 6, 7],而不是原始文本中的位置,即[0, 1, 2, 3, 6, 7, 8, 9]。对于像RoPE这样的编码,在引入旋转变换之前先缓存token的key。然后,在每个解码阶段对滚动缓存中的key应用位置变换。另一方面,与ALiBi集成更为直接。在这里,应用连续线性偏差代替跳跃偏差到注意力得分。这种方法为缓存内分配位置嵌入对于StreamingLLM的功能至关重要,确保模型即使在超出其预训练注意力窗口大小的情况下也能有效地运行。
六、提出变体SoftMax(SoftMax-off-by-One)
原始的SoftMax函数:
SoftMax-off-by-One函数:
不需要所有上下文token的关注分数之和为1的方法也可能有效。请注意,这种SoftMax替代方法相当于在注意力计算中使用具有全零key和value特征的token。我们将这种方法称为"Zero Sink "以使其与StreamingLLM框架中的一致。
为了验证,在相同的设置下从头开始预训练了三个具有1.6亿个参数的语言模型。第一个模型使用标准SoftMax注意力(Vanilla ),第二个用SoftMax1替换常规注意力机制(Zero Sink ),第三个在所有训练样本中前置一个可学习的占位符token(Sink Token )。如上表所示,虽然Zero Sink 在一定程度上缓解了Attention Sink 问题,但模型仍然依赖于其他初始token作为Attention Sink 。引入Sink Token 对于稳定注意力机制非常有效。简单地让这个Sink Token 与最近Token配对就足以锚定模型的性能,而且产生的评估困惑度甚至略有改善。根据这些发现,建议在所有样本中都使用Sink Token训练未来的LLM,以优化流式部署。