Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
2024-10-10

Summary
This paper presents Stuffed Mamba, a study of why recurrent neural networks (RNNs) break down on long inputs and how to improve their ability to handle long sequences of data without losing performance.
What's the problem?
Recurrent neural networks (RNNs) are designed to process sequences of data, like sentences or time series, but they often struggle with sequences longer than those they were trained on. This leads to a problem known as 'state collapse,' where the model's performance drops sharply once inputs exceed the training length. In addition, because an RNN compresses everything it has read into a fixed-size recurrent state, there is an upper bound on how much information it can remember from a long context.
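To make the fixed-size-state point concrete, below is a minimal sketch of a gated linear recurrence in plain NumPy (illustrative only, not the paper's Mamba-2 code; the layer sizes and decay value are assumptions). The recurrent state h keeps the same shape no matter how long the input is, which is why compute grows only linearly with sequence length but the amount of remembered context is bounded.

```python
# Minimal sketch (illustrative, not the paper's Mamba-2 implementation):
# a gated linear recurrence with a fixed-size recurrent state.
import numpy as np

def gated_linear_recurrence(x, decay, d_state=16, seed=0):
    """x: (seq_len, d_model) inputs; decay in (0, 1) controls how fast old content fades."""
    rng = np.random.default_rng(seed)
    seq_len, d_model = x.shape
    W_in = rng.standard_normal((d_model, d_state)) / np.sqrt(d_model)
    W_out = rng.standard_normal((d_state, d_model)) / np.sqrt(d_state)
    h = np.zeros(d_state)                    # fixed-size recurrent state
    outputs = []
    for t in range(seq_len):
        h = decay * h + x[t] @ W_in          # old information decays, new information is written
        outputs.append(h @ W_out)
    return np.stack(outputs)

y = gated_linear_recurrence(np.random.randn(1024, 32), decay=0.99)
print(y.shape)  # (1024, 32): cost scales linearly with length, but h stays shape (16,)
```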
What's the solution?
The authors investigate the causes of state collapse and propose ways to let RNNs process sequences much longer than their training length. Through controlled experiments, they attribute state collapse to overfitting: the recurrent state is over-parameterized for the training length, which leads to overfitting. To mitigate this, they suggest three methods that help the model forget information it no longer needs and one method based on continual training on longer sequences. With these mitigations, a Mamba-2 model can process more than 1 million tokens without collapsing; the authors also train a series of Mamba-2 models on long documents to estimate how much context the recurrent state can actually hold.
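As intuition for what a 'forgetting' mitigation might look like mechanically, the sketch below caps the norm of the recurrent state so it cannot drift far outside the range seen during training. This is a hypothetical illustration under assumed names and thresholds (capped_state_update, max_norm), not the authors' actual mitigation methods.

```python
# Hypothetical illustration of forced forgetting -- not the paper's exact method.
import numpy as np

def capped_state_update(h, x_t, W_in, decay, max_norm):
    """One recurrent step, followed by an assumed norm cap on the state."""
    h = decay * h + x_t @ W_in         # usual gated linear update
    norm = np.linalg.norm(h)
    if norm > max_norm:                # state drifting outside its training-time regime
        h = h * (max_norm / norm)      # rescale instead of letting it grow without bound
    return h

# Toy usage: without the cap, the state norm would settle around 1600 here;
# with it, the state stays bounded no matter how long the input runs.
h, W_in = np.zeros(4), 0.1 * np.ones((8, 4))
for _ in range(50_000):                # far longer than any "training length"
    h = capped_state_update(h, np.ones(8), W_in, decay=0.999, max_norm=10.0)
print(np.linalg.norm(h))               # <= 10.0
```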
Why it matters?
This research is important because it enhances the capabilities of RNNs, making them more effective for tasks that require understanding long sequences, such as language modeling and information retrieval. By improving how RNNs manage memory and context, this work can lead to better performance in various applications, including natural language processing and AI systems that rely on sequential data.
Abstract
One essential advantage of recurrent neural networks (RNNs) over transformer-based language models is their linear computational complexity concerning the sequence length, which makes them much faster in handling long sequences during inference. However, most publicly available RNNs (e.g., Mamba and RWKV) are trained on sequences with less than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfying so far. In this paper, we study the cause of the inability to process long context for RNNs and suggest critical mitigations. We examine two practical concerns when applying state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to inputs longer than the training length and (2) the upper bound of memory capacity. Addressing the first concern, we first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training. With controlled experiments, we attribute this to overfitting due to the recurrent state being overparameterized for the training length. For the second concern, we train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval. Then, three SC mitigation methods are proposed to improve Mamba-2's length generalizability, allowing the model to process more than 1M tokens without SC. We also find that the recurrent state capacity in passkey retrieval scales exponentially to the state size, and we empirically train a Mamba-2 370M with near-perfect passkey retrieval accuracy on 256K context length. This suggests a promising future for RNN-based long-context modeling.
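For context, the passkey retrieval task mentioned in the abstract hides a random key inside a long stretch of distractor text and asks the model to recall it at the end. The sketch below shows one common way such prompts are built; the filler text, prompt format, and helper name (make_passkey_prompt) are assumptions for illustration, not the paper's exact setup.

```python
# Illustrative passkey-retrieval prompt builder (assumed format, not the paper's exact setup).
import random

def make_passkey_prompt(n_filler_sentences=50_000, seed=0):
    """Hide a random 5-digit passkey inside a long run of distractor sentences."""
    random.seed(seed)
    passkey = random.randint(10_000, 99_999)
    sentences = ["The grass is green. The sky is blue."] * n_filler_sentences
    insert_at = random.randrange(len(sentences))
    sentences.insert(insert_at, f"The passkey is {passkey}. Remember it.")
    prompt = " ".join(sentences) + "\nWhat is the passkey? The passkey is"
    return prompt, passkey

prompt, answer = make_passkey_prompt()
# Feed `prompt` to the model and check whether its continuation contains `answer`.
```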