
Learning to (Learn at Test Time): RNNs with Expressive Hidden States

Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin

2024-07-08


Summary

This paper introduces a new kind of recurrent neural network (RNN) layer, called a Test-Time Training (TTT) layer, whose hidden state is itself a small machine learning model. Because that hidden state keeps training on each sequence as it is processed, including at test time, the model can adapt its internal memory on the fly, which improves its ability to handle long sequences.
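Written out, the idea looks roughly like this. As we read the paper, the hidden state at step t is the set of weights W_t of a model f; each incoming token x_t triggers one gradient step, with learning rate \eta, on a self-supervised loss \ell, and the layer's output z_t is computed with the updated weights. The second line shows the simplest form of \ell, a reconstruction of x_t from a corrupted copy \tilde{x}_t; treat that exact loss as an assumption of this sketch rather than the paper's final design.

```latex
W_t = W_{t-1} - \eta \, \nabla \ell(W_{t-1}; x_t), \qquad z_t = f(x_t; W_t)
\ell(W; x_t) = \lVert f(\tilde{x}_t; W) - x_t \rVert^2
```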

What's the problem?

The main problem is a trade-off in sequence modeling. Self-attention works well on long contexts, but its cost grows quadratically with sequence length, which makes it slow on long inputs. Existing RNNs have linear complexity, but they compress everything they have seen into a fixed-size hidden state (their internal memory) whose expressive power is limited, so their performance on long contexts falls behind.
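To make the quadratic-versus-linear contrast concrete, here is a small counting sketch. It only counts operations (query-key scores for causal self-attention, state updates for an RNN) and ignores constants such as head dimension and hidden-state size, so it illustrates growth rates and is not a benchmark. The function names are hypothetical.

```python
def attention_score_count(n: int) -> int:
    """Query-key scores computed by causal self-attention over n tokens.

    Token t attends to itself and every earlier token (t scores), so the
    total is 1 + 2 + ... + n = n * (n + 1) / 2, which is quadratic in n.
    """
    return n * (n + 1) // 2


def rnn_update_count(n: int) -> int:
    """Hidden-state updates performed by an RNN over n tokens: one per token."""
    return n


if __name__ == "__main__":
    for n in (1_000, 8_000, 16_000):
        # columns: tokens, attention scores, RNN state updates
        print(n, attention_score_count(n), rnn_update_count(n))
```

Going from 8,000 to 16,000 tokens raises the attention count from 32,004,000 to 128,008,000 (about 4x) while the RNN count only doubles. The catch is that each RNN update works on a fixed-size state, which is exactly the expressiveness limit described above.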

What's the solution?

To address this, the authors propose a new class of sequence modeling layers called Test-Time Training (TTT) layers. In a TTT layer, the hidden state is itself a small machine learning model, and the rule that updates it is a step of self-supervised learning, so the hidden state keeps learning from new data even on test sequences. The authors built two versions: TTT-Linear, whose hidden state is a linear model, and TTT-MLP, whose hidden state is a two-layer MLP. Evaluated at scales from 125M to 1.3B parameters against a strong Transformer and Mamba, a modern RNN, both TTT-Linear and TTT-MLP matched or exceeded the baselines. Like the Transformer, they kept reducing perplexity as they conditioned on more tokens, whereas Mamba stopped improving beyond 16k context.
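The core mechanism of TTT-Linear can be sketched in a few lines of plain Python. This is a minimal sketch under stated assumptions, not the authors' implementation: the hidden state is a matrix W; for each token the layer takes one gradient step on an assumed reconstruction loss ||W k - v||^2, where k, v and q stand in for the token's learned key, value and query projections (here passed in directly); W starts at zero; and the output is read out with the updated state. The paper's layers include further details (for example, how the projections are learned and how updates are batched for speed) that are omitted. All function names are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Multiply a (d_out x d_in) matrix by a length-d_in vector."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def reconstruction_loss(W: Matrix, k: Vector, v: Vector) -> float:
    """Assumed self-supervised loss ||W k - v||^2."""
    return sum((p - t) ** 2 for p, t in zip(matvec(W, k), v))


def ttt_linear_step(W: Matrix, k: Vector, v: Vector, lr: float) -> Matrix:
    """One gradient step; the gradient of ||W k - v||^2 w.r.t. W is 2 (W k - v) k^T."""
    residual = [p - t for p, t in zip(matvec(W, k), v)]
    return [
        [w - lr * 2.0 * r_i * k_j for w, k_j in zip(row, k)]
        for row, r_i in zip(W, residual)
    ]


def ttt_linear_layer(
    keys: List[Vector], values: List[Vector], queries: List[Vector], lr: float = 0.1
) -> Tuple[List[Vector], Matrix]:
    """Scan the sequence: update the hidden state W, then output z_t = W_t q_t."""
    d_in, d_out = len(keys[0]), len(values[0])
    W = [[0.0] * d_in for _ in range(d_out)]  # hidden state W_0 (zero init is an assumption)
    outputs = []
    for k, v, q in zip(keys, values, queries):
        W = ttt_linear_step(W, k, v, lr)  # the hidden state trains on the test sequence
        outputs.append(matvec(W, q))      # read out with the updated state
    return outputs, W


# Usage: two tokens with orthogonal keys.
outs, W = ttt_linear_layer(
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[2.0, 0.0], [0.0, 3.0]],
    queries=[[1.0, 0.0], [0.0, 1.0]],
)
print(W)  # approximately [[0.4, 0.0], [0.0, 0.6]]
```

Each step nudges W toward mapping a key to its value, so what the layer has seen is stored in the weights of its hidden-state model rather than in a growing cache; the state stays d_out x d_in no matter how long the sequence is.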

Why does it matter?

This research is important because it points toward sequence models that keep the linear cost of RNNs while, like Transformers, continuing to benefit from longer context. Because the hidden state adapts to each sequence as it is read, without retraining the whole model, this approach could help in real-world applications where data is long, changing, or unpredictable, such as natural language processing, robotics, and other areas that depend on understanding long sequences of information.

Abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.
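TTT-MLP replaces the linear hidden state with a two-layer MLP. The sketch below keeps the same assumed reconstruction loss and one update step per token, with hand-written backpropagation through f(k) = W2 relu(W1 k). ReLU is our simplification to keep the gradients short; the activation, normalization and width used in the paper may differ. Function names and the explicit starting weights are hypothetical. Unlike the linear sketch, the weights cannot all start at zero: with W1 = 0 every ReLU output is zero and no gradient flows, so the usage example starts from explicit nonzero weights.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def mlp_forward(W1: Matrix, W2: Matrix, k: Vector) -> Tuple[Vector, Vector, Vector]:
    """Hidden-state model f(k) = W2 relu(W1 k); returns pre-activation, hidden, output."""
    a = matvec(W1, k)
    h = [max(0.0, a_j) for a_j in a]
    return a, h, matvec(W2, h)


def mlp_loss(W1: Matrix, W2: Matrix, k: Vector, v: Vector) -> float:
    """Assumed self-supervised loss ||f(k) - v||^2."""
    _, _, out = mlp_forward(W1, W2, k)
    return sum((o - t) ** 2 for o, t in zip(out, v))


def ttt_mlp_step(
    W1: Matrix, W2: Matrix, k: Vector, v: Vector, lr: float
) -> Tuple[Matrix, Matrix]:
    """One gradient-descent step on ||f(k) - v||^2, backpropagating by hand."""
    a, h, out = mlp_forward(W1, W2, k)
    g = [2.0 * (o - t) for o, t in zip(out, v)]  # dL/d(out)
    dh = [sum(W2[i][j] * g[i] for i in range(len(g))) for j in range(len(h))]  # W2^T g
    da = [d if a_j > 0.0 else 0.0 for d, a_j in zip(dh, a)]  # ReLU gate
    new_W2 = [[w - lr * g_i * h_j for w, h_j in zip(row, h)] for row, g_i in zip(W2, g)]
    new_W1 = [[w - lr * d_i * k_j for w, k_j in zip(row, k)] for row, d_i in zip(W1, da)]
    return new_W1, new_W2


def ttt_mlp_layer(keys, values, queries, W1, W2, lr=0.01):
    """Scan the sequence: one TTT step per token, then read out f(q_t; W_t)."""
    outputs = []
    for k, v, q in zip(keys, values, queries):
        W1, W2 = ttt_mlp_step(W1, W2, k, v, lr)
        outputs.append(mlp_forward(W1, W2, q)[2])
    return outputs, W1, W2


# Usage: feed the same (key, value) pair ten times and compare the loss.
W1_init = [[1.0, 0.0], [0.0, 1.0]]
W2_init = [[1.0, 1.0]]
k, v = [1.0, 2.0], [0.0]
outs, W1, W2 = ttt_mlp_layer([k] * 10, [v] * 10, [k] * 10, W1_init, W2_init)
print(mlp_loss(W1_init, W2_init, k, v), mlp_loss(W1, W2, k, v))  # 9.0, then a smaller value
```

A more expressive hidden state also means more weights to read and update per token, which is one plausible intuition for why TTT-MLP is harder to make fast; the abstract attributes its current challenges to memory I/O.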