Mask-Enhanced Autoregressive Prediction: Pay Less Attention to Learn More
Xialie Zhuang, Zhikai Jia, Jianjin Li, Zhenyu Zhang, Li Shen, Zheng Cao, Shiwei Liu
2025-02-12
Summary
This paper introduces a new way to train large language models called Mask-Enhanced Autoregressive Prediction (MEAP), which helps AI models find and remember important information in text more reliably.
What's the problem?
Big AI language models are really good at many things, but they sometimes struggle to pick out and remember the most important bits of information from long texts. This makes it hard for them to answer questions accurately or understand complex ideas.
What's the solution?
The researchers created MEAP, which works by hiding (or 'masking') a small fraction of the words in the training text and then having the AI carry on with its usual task of predicting the next word. Because some words are hidden, the AI has to pay closer attention to the words that remain, which teaches it to focus on the important parts of the text and rely less on surrounding noise. A big plus is that MEAP needs no extra computing power, either during training or when the model is used.
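The masking step described above can be sketched in a few lines: randomly replace a small fraction of input tokens with a mask token, while keeping the ordinary next-token targets untouched. This is a minimal illustration, not the paper's implementation; the mask ratio and the `mask_id` value here are arbitrary choices for the example.

```python
import random

def meap_batch(tokens, mask_id, mask_ratio=0.15, rng=None):
    """Build one MEAP-style training example.

    A small fraction of the *input* positions are replaced with
    `mask_id`, but the next-token-prediction targets are left exactly
    as in standard autoregressive training. The 15% ratio is an
    illustrative choice, not a value taken from the paper.
    """
    rng = rng or random.Random(0)
    inputs = list(tokens[:-1])     # positions the decoder conditions on
    targets = list(tokens[1:])     # ordinary next-token targets (unmasked)
    n_mask = max(1, int(len(inputs) * mask_ratio))
    for pos in rng.sample(range(len(inputs)), n_mask):
        inputs[pos] = mask_id      # hide this token from the model
    return inputs, targets
```

A decoder-only model would then be trained on `(inputs, targets)` with the usual causal-attention cross-entropy loss, so no architectural change or extra forward pass is needed.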
Why it matters
This matters because it could make AI language models much better at understanding and using important information from long texts. In tests, MEAP did much better than other methods at tasks that need good memory and understanding of context. This could lead to AI assistants that are better at answering questions, summarizing long documents, or helping with complex research tasks.
Abstract
Large Language Models (LLMs) often struggle to accurately retrieve key information from their context. To address this, we propose Mask-Enhanced Autoregressive Prediction (MEAP), a simple yet effective training paradigm that seamlessly integrates Masked Language Modeling (MLM) into Next-Token Prediction (NTP) to enhance the latter's in-context retrieval capabilities. Specifically, MEAP first randomly masks a small fraction of input tokens and then performs standard next-token prediction autoregressively using a decoder-only Transformer. MEAP eliminates the need for bidirectional attention or encoder-decoder architectures for MLM, incurring no additional computational overhead during pre-training or inference. Extensive experiments demonstrate that MEAP substantially outperforms NTP on key information retrieval and long-context reasoning tasks, while performing on par with or better than it on commonsense reasoning tasks. The benefits of MEAP also extend to supervised fine-tuning, where it shows remarkable advantages in lost-in-the-middle scenarios, outperforming NTP by 11.77 percentage points. Our analysis indicates that MEAP's effectiveness arises from its ability to promote more distinguishable attention scores by concentrating on a reduced set of non-masked tokens. This mechanism improves the model's focus on task-relevant signals while mitigating the influence of peripheral context. These findings position MEAP as a promising training paradigm for large language models.