TidalDecode: Fast and Accurate LLM Decoding with Position Persistent Sparse Attention
Lijie Yang, Zhihao Zhang, Zhuofu Chen, Zikun Li, Zhihao Jia
2024-10-09

Summary
This paper introduces TidalDecode, a method that speeds up text generation in large language models (LLMs) using a technique called position persistent sparse attention.
What's the problem?
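As LLMs handle longer inputs, the key-value (KV) cache they must store and read at every decoding step grows with the context length, which makes memory a major bottleneck during decoding. Existing sparse attention methods that try to reduce this load often fail to reliably identify the most relevant tokens. They also ignore the fact that consecutive layers tend to select similar tokens, which can degrade output quality and add substantial token-selection overhead.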
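A rough calculation shows how quickly this memory grows. The sketch below applies the standard KV cache size formula (one key and one value vector per token, per KV head, per layer). The model configuration is a hypothetical example chosen for illustration, not one evaluated in the paper, and `kv_cache_bytes` is an illustrative helper name.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """KV cache size: keys and values (factor 2) for every token, KV head, and layer."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem * batch

# Hypothetical Llama-style configuration (not taken from the paper):
# 32 layers, 8 KV heads of dimension 128, fp16 storage (2 bytes), 128K-token context.
size = kv_cache_bytes(num_layers=32, num_kv_heads=8, head_dim=128, seq_len=131_072)
print(size / 2**30, "GiB")  # 16.0 GiB
```

Under full attention, every generated token requires reading this entire cache once, so both the memory footprint and the memory traffic per decoding step grow linearly with context length.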
What's the solution?
TidalDecode combines full attention and sparse attention. A few token selection layers perform full attention to find the tokens with the highest attention scores, and all other layers perform sparse attention over only those pre-selected token positions, reusing them instead of selecting tokens again. This keeps generation quality high while substantially reducing the cost of token selection and the time spent per decoding step. Across a range of LLMs and tasks, the authors found that TidalDecode closely matched the generative performance of full attention while decoding up to 2.1 times faster.
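The control flow can be illustrated with a toy NumPy sketch. It is not the authors' implementation: it assumes a single attention head, a bare residual stream with no MLP or normalization, random weights, and a hypothetical choice of selection layers ({0, 4}) and token budget (k = 8). The names `attend` and `decode_step` are illustrative. Selection layers run full attention and record the top-k positions; the layers after them attend only to those positions.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

def attend(q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """Single-head scaled dot-product attention for one query vector."""
    scores = K @ q / np.sqrt(q.shape[0])
    return softmax(scores) @ V, scores

def decode_step(h, Wq, K_cache, V_cache, selection_layers, k):
    """One toy decoding step with position persistent sparse attention.

    h: hidden state of shape (d,); Wq: per-layer (d, d) query projections;
    K_cache, V_cache: per-layer (n, d) key/value caches over the same n tokens.
    Layers in `selection_layers` run full attention and pick the top-k
    positions; every other layer attends only to the most recent selection.
    """
    selected, full_layers = None, 0
    for layer, (W, K, V) in enumerate(zip(Wq, K_cache, V_cache)):
        q = W @ h
        # Layers before any selection layer have nothing to reuse, so they run full attention.
        if layer in selection_layers or selected is None:
            out, scores = attend(q, K, V)        # full attention over all n positions
            selected = np.argsort(scores)[-k:]   # keep the k highest-scoring positions
            full_layers += 1
        else:
            out, _ = attend(q, K[selected], V[selected])  # reuse the same token positions
        h = h + out                              # toy residual update (no MLP or norm)
    return h, selected, full_layers

rng = np.random.default_rng(0)
d, n, num_layers = 16, 64, 8
Wq = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(num_layers)]
K_cache = [rng.standard_normal((n, d)) for _ in range(num_layers)]
V_cache = [rng.standard_normal((n, d)) for _ in range(num_layers)]
h0 = rng.standard_normal(d)

h, selected, full_layers = decode_step(h0, Wq, K_cache, V_cache,
                                       selection_layers={0, 4}, k=8)
print(full_layers, len(selected))  # 2 8
```

Because the non-selection layers index the cache with the shared `selected` positions, they read k rather than n key/value rows and repeat no selection work. Setting k equal to n makes every layer see all positions, which recovers full attention and serves as a quick sanity check on the sketch.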
Why does it matter?
This research matters because it makes LLMs more efficient, and therefore more practical, for real-world applications such as chatbots, translation services, and content creation. By making long-context decoding faster without sacrificing output quality, TidalDecode could help make advanced AI technologies more accessible and useful.
Abstract
Large language models (LLMs) have driven significant advancements across diverse NLP tasks, with long-context models gaining prominence for handling extended inputs. However, the expanding key-value (KV) cache size required by Transformer architectures intensifies the memory constraints, particularly during the decoding phase, creating a significant bottleneck. Existing sparse attention mechanisms designed to address this bottleneck have two limitations: (1) they often fail to reliably identify the most relevant tokens for attention, and (2) they overlook the spatial coherence of token selection across consecutive Transformer layers, which can lead to performance degradation and substantial overhead in token selection. This paper introduces TidalDecode, a simple yet effective algorithm and system for fast and accurate LLM decoding through position persistent sparse attention. TidalDecode leverages the spatial coherence of tokens selected by existing sparse attention methods and introduces a few token selection layers that perform full attention to identify the tokens with the highest attention scores, while all other layers perform sparse attention with the pre-selected tokens. This design enables TidalDecode to substantially reduce the overhead of token selection for sparse attention without sacrificing the quality of the generated results. Evaluation on a diverse set of LLMs and tasks shows that TidalDecode closely matches the generative performance of full attention methods while reducing the LLM decoding latency by up to 2.1x.
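To see where the savings come from, the sketch below counts how many cached key/value positions are attended to per decoding step under a simple cost model: full attention attends to all n positions in every layer, whereas position persistent sparse attention does so only in its token selection layers and attends to k pre-selected positions elsewhere. The layer count, context length, token budget, and number of selection layers are hypothetical values, not the paper's configuration, and `attended_positions` is an illustrative helper.

```python
def attended_positions(num_layers: int, context_len: int, budget: int,
                       num_selection_layers: int) -> tuple[int, int]:
    """Toy cost model: KV positions attended per decoding step (attention only)."""
    full = num_layers * context_len
    persistent = (num_selection_layers * context_len
                  + (num_layers - num_selection_layers) * budget)
    return full, persistent

# Hypothetical setting: 32 layers, 32K-token context, 2,048-token budget, 2 selection layers.
full, persistent = attended_positions(32, 32_768, 2_048, 2)
print(full, persistent, round(full / persistent, 2))  # 1048576 126976 8.26
```

This counts attention reads only and ignores projections, MLPs, and other per-layer work, so it is not an estimate of end-to-end latency; the measured speedup reported above is up to 2.1x.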