个人技术分享

随着输入长度的增加,大型语言模型(LLMs)中的键值(KV)缓存需要存储更多的上下文信息以维持性能,这导致内存消耗和计算时间急剧上升。KV缓存的增长对内存和时间效率的挑战主要表现在两个方面:一是在处理长文本时,模型需要更多的内存资源来存储KV缓存,这不仅增加了硬件成本,还可能因内存限制而影响模型规模的扩展;二是在生成文本时,模型需要对KV缓存中的每个键值对进行注意力计算,随着缓存的增大,这个过程变得更加耗时,从而降低了模型的解码速度。

在深入分析大型语言模型(LLMs)的注意力机制时,研究者发现了一些关键模式,这些模式对于优化KV缓存至关重要。某些键在令牌生成期间始终吸引着模型的注意力,无论上下文的长度如何,这些“活跃”的键展现出稳定的高注意力权重。在长摘要和问答任务中,问题的位置(无论是在提示的开头还是结尾)对模型的注意力分配模式影响不大,显示出模型在处理长文本时的鲁棒性。研究者还发现注意力模式高度依赖于上下文,与用户的具体指令密切相关,这意味着不同的指令会引导模型关注不同的信息。

在超过 3k 的平均输入长度下,重要位置在生成过程中的逐层平均命中率。它显示了在多轮对话中,观察窗口中识别出的关键特征在随后的标记生成中保持了显著的一致性

这些观察结果促成了SnapKV的开发,它是一种创新的KV缓存压缩方法。SnapKV的创新之处在于它提出了一种无需微调的压缩方法,通过观察模型在生成过程中的注意力分配模式,自动识别并压缩KV缓存中的关键信息。SnapKV通过“投票”机制选出每个注意力头关注的关键KV位置,并通过聚类算法保留这些关键特征周围的信息,从而在不牺牲准确性的前提下显著减少KV缓存的大小。这种方法不仅减少了计算开销,还提高了内存效率,使得模型在处理长文本时更为高效。SnapKV在解码速度上实现了3.6倍的提升,在内存效率上实现了8.2倍的提升,同时在多个长序列数据集上保持了与基线模型相当的性能。SnapKV还能够与现有的深度学习框架轻松集成,仅需少量代码调整,为长文本处理提供了一种实用的解决方案。

 SnapKV 的简化工作流程。橙色区域表示 SnapKV 为每个注意力头选择的每头一组位置的聚类。这些聚类特征随后与观察窗口中的标记一起形成新的键值对,与观察窗口(标记为 "Window")连接。所选的前缀和观察窗口共同构成了用于生成的新 KV 缓存

SnapKV 的核心思想是在生成过程中保持提示(prompt)的 KV 缓存数量恒定,从而显著减少长上下文 LLMs 的服务时间。这是通过识别和选择每个注意力头(attention head)最关键的注意力特征(attention features)来实现的,以此创建一个新的、更小的 KV 缓存。

实现步骤

SnapKV 的实现分为两个主要阶段:

  1. 投票选择重要特征(Voting for Important Previous Features):

    • 利用定义好的投票过程(如公式1所示),基于观察窗口(observation window)——即提示的最后部分——来选择重要的特征。
    • 通过分析发现,这些特征在整个序列生成过程中表现出显著的一致性,表明它们对后续生成至关重要。
    • 此外,实施聚类算法以保留选定特征周围的特征,这有助于保留信息的完整性并避免丢失上下文。
  2. 更新和存储截断的键和值(Update and Store Truncated Key and Value):

    • 将选定的特征与观察窗口的特征连接起来,这些特征包含了所有提示信息。
    • 将连接后的 KV 缓存存储起来,以供后续生成使用,同时节省内存使用。

代码示例

def snap_kv(query_states, key_states, value_states, window_size, max_capacity_prompt, kernel_size):
    bsz, num_heads, q_len, head_dim = query_states.shape
    # 确保当前是处理提示阶段
    assert key_states.shape[-2] == query_states.shape[-2]
    if q_len < max_capacity_prompt:
        return key_states, value_states
    else:
        # 计算观察窗口的查询和前缀上下文的键的注意力权重
        attn_weights = compute_attn(query_states[..., -window_size:, :], key_states, attention_mask)
        # 沿着查询维度对权重求和
        attn_weights_sum = attn_weights[..., -window_size:, :-window_size].sum(dim=-2)
        # 应用1D池化进行聚类
        attn_cache = pool1d(attn_weights_sum, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
        # 基于池化后的权重选择每个头的top-k索引,以识别重要位置
        indices = attn_cache.topk(max_capacity_prompt - window_size, dim=-1).indices
        # 扩展索引以匹配头维度进行聚集
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        # 根据选定的索引聚集压缩的过去键和值状态
        k_past_compress = key_states[..., :-window_size, :].gather(dim=2, index=indices)
        v_past_compress = value_states[..., :-window_size, :].gather(dim=2, index=indices)
        k_obs = key_states[..., -window_size:, :]
        v_obs = value_states[..., -window_size:, :]
        # 将压缩后的过去键和观察窗口的键拼接在一起
        key_states = torch.cat([k_past_compress, k_obs], dim=2)
        # 将压缩后的过去值和观察窗口的值拼接在一起
        value_states = torch.cat([v_past_compress, v_obs], dim=2)
    return key_states, value_states
  • snap_kv 函数接受查询状态 query_states、键状态 key_states、值状态 value_states,以及其他参数如窗口大小 window_size、最大提示容量 max_capacity_prompt 和池化核大小 kernel_size
  • 首先,检查是否处于处理提示的阶段,如果不是,则直接返回原始的键和值状态。
  • 然后,计算观察窗口内查询和前缀上下文键之间的注意力权重 attn_weights
  • 对这些权重进行求和 attn_weights_sum,然后应用一维池化 pool1d 来聚类,以便选择重要的特征。
  • 使用 topk 方法根据池化后的权重选择每个头的重要位置 indices
  • 根据这些索引聚集压缩后的键和值状态 k_past_compress 和 v_past_compress
  • 最后,将压缩的键和值状态与观察窗口的键和值状态拼接起来,形成新的键和值状态,这些状态将用于后续的生成过程。

SnapKV方法论包含两个主要阶段:首先是通过一个称为“投票”的过程来识别重要的先前特征,其次是更新并存储截断的键和值。

在第一阶段,SnapKV利用了一个观察窗口,这个窗口位于提示的末端,它的作用是捕捉模型在生成过程中所关注的关键特征。通过计算观察窗口中每个查询的注意力权重,并在所有注意力头上聚合这些权重,SnapKV能够突出显示被认为是最重要的前缀位置。这个过程称为投票,它帮助系统识别出那些在生成文本时需要特别关注的KV位置。

第二阶段中,SnapKV将这些选出的重要特征与观察窗口中的特征结合起来,形成一个新的键值对。这个新的键值对随后被用于生成过程,同时通过仅保留这些关键特征,系统能够显著减少所需的内存和计算资源。此外,为了保持信息的完整性并避免因过度压缩而导致的细节丢失,SnapKV采用了一种基于池化的聚类算法。这个算法通过池化层对信息进行细粒度的压缩,确保了在压缩KV缓存的同时,依然能够保留足够的上下文信息,从而维持模型的准确性。

SnapKV的实现是高效的,因为它只需要对现有的深度学习框架进行少量的代码调整。这意味着它可以轻松地集成到现有的系统中,而无需进行大规模的重构。在实验中,SnapKV显示出了卓越的性能,它不仅在解码速度上实现了显著的提升,还在内存效率上达到了大幅度的增强。这些改进使得SnapKV成为一个在处理长文本方面极具潜力的解决方案,特别是在需要处理大量输入序列的应用场景中,如聊天机器人、代理服务、文档处理等。SnapKV通过其创新的方法论,为长文本处理中的内存和时间效率问题提供了一个有效的解决方案。

在对SnapKV进行的实验中,研究团队采取了一系列严谨的测试,旨在评估该方法在不同模型和长文本数据集上的性能。实验的目的是验证SnapKV在减少计算和内存负担的同时,是否能够保持或甚至提升模型的生成质量和效率。

实验首先在LWM-Text-Chat-1M模型上进行了压力测试,这是当时最先进的模型之一,能够处理长达一百万个令牌的上下文。测试中,SnapKV展现了其算法效率,特别是在硬件优化方面。通过“Needle-in-a-Haystack”测试,即在长达380K令牌的文档中准确检索特定句子的能力,SnapKV证明了其在极端条件下处理长文本的能力,即便在极高的压缩比下也能保持精确性。

在单 A100-80GB GPU 上,使用 SnapKV 优化的模型与原始实现在 "Needle-in-a-Haystack" 测试中的性能比较。图表显示了 SnapKV 在处理极长输入上下文时准确管理细节的潜力

实验通过不同的批处理大小对LWM-Text-Chat-1M模型进行了解码速度和内存限制的基准测试。结果表明,SnapKV优化的模型在解码速度上保持了稳定,与输入序列长度的增加无关,这与基线实现形成了鲜明对比,后者的解码速度随输入长度的增加而指数级增长。SnapKV显著提高了模型处理长序列的能力,显著减少了内存消耗。

在不同批量大小设置下,基线实现与 SnapKV 优化模型在解码速度和输入序列长度方面的比较。图表显示了 SnapKV 在保持解码速度恒定方面的优势

为了进一步验证SnapKV的有效性,研究团队还对Mistral-7B-Instruct-v0.2模型进行了消融研究,以理解池化技术对模型信息检索性能的影响。消融研究结果表明,通过池化增强了检索准确性,这可能是因为强大的注意力机制倾向于关注令牌序列的初始部分。

池化技术在 LongEval-Lines 测试中对检索准确性的影响。结果表明,使用池化的模型在压缩 KV 缓存时的性能显著优于不使用池化的模型

最后,实验使用了LongBench,这是一个多任务基准测试,旨在全面评估长文本理解能力。SnapKV在多个不同设置下进行了测试,包括压缩KV缓存到1024、2048和4096个令牌,并使用最大池化和观察窗口。测试结果显示,即使在压缩率高达92%的情况下,SnapKV与原始实现相比,在16个不同数据集上的性能下降可以忽略不计,某些情况下甚至超过了基线模型。

 SnapKV 和 H2O 在 LongBench 数据集上不同 LLMs 的性能比较。表格显示了 SnapKV 在压缩提示 KV 缓存时的性能与原始实现相比几乎没有下降

实验结果,SnapKV证明了其作为一种有效的KV缓存压缩方法,在保持大型语言模型处理长文本的能力的同时,显著提升了效率和减少了资源消耗。这些发现不仅证实了SnapKV的实用性,还为未来的研究和应用提供了有价值的见解。

SnapKV代码可在https://github.com/FasterDecoding/SnapKV上找到

论文链接:https://arxiv.org/abs/2404.14469