DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads

Guangxuan Xiao, Jiaming Tang, Jingwei Zuo, Junxian Guo, Shang Yang, Haotian Tang, Yao Fu, Song Han

2024-10-15

Summary

This paper introduces DuoAttention, a method that reduces the memory use and latency of large language models (LLMs) when they process long inputs.

What's the problem?

When LLMs process long texts, memory usage and speed become major bottlenecks. The model caches the Key and Value (KV) states of every previous token in every attention head, and this cache grows linearly with the length of the input, quickly filling GPU memory and slowing down inference. Existing methods that prune the KV cache to save memory either weaken the model's ability to use long contexts or deliver only limited efficiency gains.
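
To see how quickly the cache grows, its size can be worked out directly: a key and a value vector are stored for every token, in every layer and every KV head. The sketch below is a back-of-the-envelope calculator. The Llama-3-8B-like shape it plugs in (32 layers, 8 KV heads, head dimension 128, 16-bit values) is an assumption for illustration, and the function name `kv_cache_bytes` is hypothetical; framework overhead is ignored.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache keys and values for every token in every KV head.

    The leading factor of 2 accounts for storing both a key and a value vector.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


# Assumed Llama-3-8B-like shape: 32 layers, 8 KV heads (GQA), head_dim 128, FP16.
per_token = kv_cache_bytes(32, 8, 128, seq_len=1)
print(per_token)  # 131072 bytes, i.e. 128 KiB per token

one_million = kv_cache_bytes(32, 8, 128, seq_len=1_000_000)
print(one_million / 2**30)  # 122.0703125 (GiB) for a 1M-token context
```

Under these assumptions, a one-million-token context needs about 122 GiB for the KV cache alone, more than the 80 GB available on the largest A100 model, which is why shrinking the cache matters.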

What's the solution?

DuoAttention addresses this by sorting the model's attention heads into two types: Retrieval Heads and Streaming Heads. Retrieval Heads, only a fraction of all heads, pull relevant details from anywhere in the text and need full attention over every token. Streaming Heads mostly attend to recent tokens and to a few initial "attention sink" tokens, so they do not need the full history. DuoAttention therefore keeps a full KV cache only for Retrieval Heads and gives Streaming Heads a small, constant-length cache. This substantially reduces memory usage and speeds up both pre-filling and decoding, with minimal loss of accuracy on long-context tasks.
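
The two cache policies can be sketched in a few lines of Python. This is a toy model, not the DuoAttention implementation: each token's key/value state is stood in for by a pair of floats, and the names `RetrievalHeadCache`, `StreamingHeadCache`, and `build_layer_cache`, as well as the default sizes (4 sink tokens, 256 recent tokens), are hypothetical.

```python
from collections import deque
from typing import Deque, List, Tuple

KV = Tuple[float, float]  # hypothetical stand-in for one token's (key, value) tensors


class RetrievalHeadCache:
    """Full KV cache: keeps every token's key/value state."""

    def __init__(self) -> None:
        self._entries: List[KV] = []

    def append(self, kv: KV) -> None:
        self._entries.append(kv)

    def entries(self) -> List[KV]:
        return list(self._entries)


class StreamingHeadCache:
    """Constant-size KV cache: the first `num_sink` tokens (attention sinks)
    plus a sliding window of the `num_recent` most recent tokens."""

    def __init__(self, num_sink: int, num_recent: int) -> None:
        self.num_sink = num_sink
        self._sink: List[KV] = []
        # A deque with maxlen drops its oldest entry automatically when full.
        self._recent: Deque[KV] = deque(maxlen=num_recent)

    def append(self, kv: KV) -> None:
        if len(self._sink) < self.num_sink:
            self._sink.append(kv)
        else:
            self._recent.append(kv)

    def entries(self) -> List[KV]:
        return self._sink + list(self._recent)


def build_layer_cache(is_retrieval: List[bool], num_sink: int = 4, num_recent: int = 256):
    """One cache per head: full for retrieval heads, sink + recent for streaming heads."""
    return [RetrievalHeadCache() if r else StreamingHeadCache(num_sink, num_recent)
            for r in is_retrieval]


# One retrieval head and three streaming heads, fed 100 tokens.
caches = build_layer_cache([True, False, False, False], num_sink=4, num_recent=8)
for t in range(100):
    for cache in caches:
        cache.append((float(t), float(t)))
lengths = [len(cache.entries()) for cache in caches]
print(lengths)  # [100, 12, 12, 12]
```

However long the input grows, each streaming-head cache stays at `num_sink + num_recent` entries (here tokens 0 to 3 plus tokens 92 to 99), while the retrieval-head cache grows with every token.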

Why does it matter?

This research is important because it makes it easier to deploy LLMs for tasks that require understanding long documents, like summarizing articles or handling long conversations. By cutting the memory and latency cost of long inputs, DuoAttention can make long-context models cheaper to run in real-world applications such as customer support, education, and content creation.

Abstract

Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a. Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which primarily focus on recent tokens and attention sinks (referred to as Streaming Heads), do not require full attention. Based on this insight, we introduce DuoAttention, a framework that applies a full KV cache only to retrieval heads while using a lightweight, constant-length KV cache for streaming heads, which reduces both the LLM's decoding and pre-filling memory and latency without compromising its long-context abilities. DuoAttention uses a lightweight, optimization-based algorithm with synthetic data to identify retrieval heads accurately. Our method significantly reduces long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models while speeding up decoding by up to 2.18x and 1.50x and accelerating pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention. Notably, combined with quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context length on a single A100 GPU. Code is available at https://github.com/mit-han-lab/duo-attention.
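
The abstract describes the head-identification step only as a lightweight, optimization-based algorithm trained on synthetic data. The sketch below assumes one plausible form of it, which is not confirmed by the text above: each head gets a gate value alpha in [0, 1] that blends its full-attention and streaming-attention outputs, an L1 penalty during training pushes gates toward 0, and heads whose gate stays above a threshold are labeled retrieval heads. Only the blend, penalty, and selection steps are shown, not the training loop. The gate values, the 0.5 threshold, the sink and window sizes, and all function names are hypothetical. The memory estimate ignores model weights, activations, and implementation overhead, so it is not meant to reproduce the figures reported in the abstract.

```python
from typing import List, Sequence


def blend_head_output(alpha: float, full_out: Sequence[float],
                      streaming_out: Sequence[float]) -> List[float]:
    """Assumed gated mix: alpha=1 is pure full attention, alpha=0 pure streaming."""
    return [alpha * f + (1.0 - alpha) * s for f, s in zip(full_out, streaming_out)]


def l1_gate_penalty(alphas: Sequence[float], weight: float) -> float:
    """Sparsity term added to the training loss; it pushes gates toward 0 (streaming)."""
    return weight * sum(abs(a) for a in alphas)


def select_retrieval_heads(alphas: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Heads whose learned gate stays above `threshold` are kept as retrieval heads."""
    return [i for i, a in enumerate(alphas) if a > threshold]


def kv_memory_fraction(num_retrieval: int, num_heads: int, seq_len: int,
                       num_sink: int, num_recent: int) -> float:
    """Rough KV cache size relative to full attention (ignores weights and overheads)."""
    streaming_len = min(seq_len, num_sink + num_recent)  # constant once the window fills
    kept = num_retrieval * seq_len + (num_heads - num_retrieval) * streaming_len
    return kept / (num_heads * seq_len)


# Hypothetical gate values for an 8-head layer after training.
alphas = [0.92, 0.03, 0.10, 0.71, 0.01, 0.02, 0.88, 0.05]
retrieval = select_retrieval_heads(alphas)
print(retrieval)  # [0, 3, 6]

fraction = kv_memory_fraction(len(retrieval), len(alphas), seq_len=100_000,
                              num_sink=4, num_recent=256)
print(fraction)  # 0.376625
```

As the context grows, this fraction approaches `num_retrieval / num_heads`, because the streaming-head caches stop growing once their window is full.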