Reasoning to Learn from Latent Thoughts
Yangjun Ruan, Neil Band, Chris J. Maddison, Tatsunori Hashimoto
2025-03-25
Summary
This paper explores a way to make AI learn more effectively, especially when there isn't a lot of training data available, by helping the AI understand the hidden thought processes behind the information it's learning.
What's the problem?
Language models improve as they train on more data, but the supply of high-quality human-written text is growing more slowly than the compute available for training. This means data, not compute, may soon become the bottleneck that limits how well these models can perform.
What's the solution?
The researchers suggest that AI can learn better if it tries to infer the 'latent thoughts' - the hidden reasoning steps - that went into creating the data it's learning from. It's like showing your work in math class: the intermediate steps help you understand the answer better. The researchers tested this idea on math problems and found that training on thought-augmented data substantially improved performance (from 5.7% to 25.4% accuracy on the MATH benchmark) compared to training on the same amount of raw data. They also showed that a model can bootstrap itself without a stronger teacher, alternately inferring latent thoughts and retraining on the improved data.
Why it matters?
This work matters because it could help AI learn more effectively with less data, which is important in situations where data is scarce or expensive to obtain. It could also help AI become better at reasoning and problem-solving by understanding the thought processes behind the information it's learning.
Abstract
Compute scaling for language model (LM) pretraining has outpaced the growth of human-written texts, leading to concerns that data will become the bottleneck to LM scaling. To continue scaling pretraining in this data-constrained regime, we propose that explicitly modeling and inferring the latent thoughts that underlie the text generation process can significantly improve pretraining data efficiency. Intuitively, our approach views web text as the compressed final outcome of a verbose human thought process, and posits that the latent thoughts contain important contextual knowledge and reasoning steps that are critical to data-efficient learning. We empirically demonstrate the effectiveness of our approach through data-constrained continued pretraining for math. We first show that synthetic data approaches to inferring latent thoughts significantly improve data efficiency, outperforming training on the same amount of raw data (5.7% → 25.4% on MATH). Furthermore, we demonstrate latent thought inference without a strong teacher, where an LM bootstraps its own performance by using an EM algorithm to iteratively improve the capability of the trained LM and the quality of thought-augmented pretraining data. We show that a 1B LM can bootstrap its performance across at least three iterations and significantly outperform baselines trained on raw data, with increasing gains from additional inference compute when performing the E-step. The gains from inference scaling and EM iterations suggest new opportunities for scaling data-constrained pretraining.
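The bootstrapping loop described in the abstract can be sketched in pseudocode-style Python. This is a minimal toy illustration, not the paper's implementation: the E-step samples several candidate latent thoughts per document and keeps the one the current model scores highest (more inference compute means more candidates), and the M-step retrains on the resulting thought-augmented corpus. All function bodies and the scalar `skill` state are hypothetical stand-ins for real LM sampling, scoring, and training.

```python
import random

def sample_thought(model, doc, rng):
    # Stand-in for sampling a latent thought from the LM, conditioned on doc.
    # A stronger model tends to produce higher-quality thoughts.
    return {"doc": doc, "quality": rng.random() * model["skill"]}

def score_thought(model, thought):
    # Stand-in for scoring a (thought, document) pair under the current model.
    return thought["quality"]

def train(model, corpus):
    # Stand-in M-step: better thought-augmented data yields a stronger model.
    avg_quality = sum(t["quality"] for t in corpus) / len(corpus)
    return {"skill": model["skill"] + avg_quality}

def em_bootstrap(raw_docs, n_iters=3, n_candidates=4, seed=0):
    rng = random.Random(seed)
    model = {"skill": 1.0}  # toy proxy for model capability
    for _ in range(n_iters):
        # E-step: best-of-N latent-thought inference for each raw document.
        # Raising n_candidates mimics spending more inference compute here.
        corpus = []
        for doc in raw_docs:
            candidates = [sample_thought(model, doc, rng)
                          for _ in range(n_candidates)]
            corpus.append(max(candidates, key=lambda t: score_thought(model, t)))
        # M-step: continued pretraining on the thought-augmented corpus.
        model = train(model, corpus)
    return model

final_model = em_bootstrap([f"doc{i}" for i in range(8)])
```

Under these toy dynamics, each iteration lifts `skill` because the selected thoughts have positive average quality, loosely mirroring the paper's finding that performance improves across EM iterations and with more E-step inference compute.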