Data Selection via Optimal Control for Language Models
Yuxian Gu, Li Dong, Hongning Wang, Yaru Hao, Qingxiu Dong, Furu Wei, Minlie Huang
2024-10-10

Summary
This paper presents a new method, grounded in Optimal Control theory, for selecting high-quality pre-training data for language models (LMs) in order to improve their performance on downstream tasks.
What's the problem?
Language models need large amounts of high-quality data to learn effectively. However, with so much data available, it can be challenging to choose the best examples for training. Poor data selection can lead to slower learning and lower performance in real-world applications.
What's the solution?
The authors propose a framework called PMP-based Data Selection (PDS), which uses a mathematical approach known as Pontryagin's Maximum Principle (PMP) to determine the best data to use for training. By applying this method, they can identify which pieces of data will help the model learn faster and perform better. They tested this approach using data from CommonCrawl and found that it improved the learning speed and overall performance of the models on a variety of tasks. Additionally, PDS helps reduce the amount of data needed for training, making it more efficient.
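At a high level, the selection step can be pictured as score-and-filter: each candidate document receives a scalar quality score (in PDS these scores come from approximately solving the PMP conditions), and only the highest-scoring fraction of the corpus is kept for pre-training. The snippet below is a minimal illustrative sketch of that filtering step, not the authors' released implementation; the function name `select_top_fraction` and the `keep_ratio` parameter are hypothetical.

```python
# Minimal, illustrative sketch (not the PDS release): given a per-document
# quality score, keep the top-scoring fraction of the corpus for pre-training.
from typing import List, Tuple


def select_top_fraction(
    scored_docs: List[Tuple[str, float]], keep_ratio: float = 0.5
) -> List[str]:
    """Keep the `keep_ratio` fraction of documents with the highest scores."""
    ranked = sorted(scored_docs, key=lambda pair: pair[1], reverse=True)
    n_keep = max(1, int(len(ranked) * keep_ratio))
    return [doc for doc, _score in ranked[:n_keep]]


if __name__ == "__main__":
    # Hypothetical scores; in PDS they would be derived from the PMP conditions,
    # however they are computed for the full corpus in practice.
    corpus = [
        ("doc A: well-formed explanatory text", 0.92),
        ("doc B: boilerplate navigation menu", 0.11),
        ("doc C: code with helpful comments", 0.78),
        ("doc D: duplicated spam page", 0.05),
    ]
    print(select_top_fraction(corpus, keep_ratio=0.5))
```

However the scores are produced, the kept fraction would in practice be chosen against the available compute and data budget.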
Why it matters?
This research is important because it provides a systematic way to select training data that enhances the capabilities of language models. By improving how models are trained, this method can lead to better AI applications in areas like natural language processing, content generation, and more, ultimately making AI systems more effective and reliable.
Abstract
This work investigates the selection of high-quality pre-training data from massive corpora to enhance LMs' capabilities for downstream usage. We formulate data selection as a generalized Optimal Control problem, which can be solved theoretically by Pontryagin's Maximum Principle (PMP), yielding a set of necessary conditions that characterize the relationship between optimal data selection and LM training dynamics. Based on these theoretical results, we introduce PMP-based Data Selection (PDS), a framework that approximates optimal data selection by solving the PMP conditions. In our experiments, we adopt PDS to select data from CommonCrawl and show that the PDS-selected corpus accelerates the learning of LMs and consistently boosts their performance on a wide range of downstream tasks across various model sizes. Moreover, the benefits of PDS extend to ~400B models trained on ~10T tokens, as evidenced by the extrapolation of the test loss curves according to the Scaling Laws. PDS also improves data utilization when the pre-training data is limited, reducing the data demand by 1.8 times and thereby mitigating the quick exhaustion of available web-crawled corpora. Our code, data, and model checkpoints can be found at https://github.com/microsoft/LMOps/tree/main/data_selection.
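For readers unfamiliar with the optimal-control framing, a generic discrete-time instance of such a problem looks like the following; the notation (per-example weights \(\gamma_n\), parameters \(\theta_t\), learning rate \(\eta\), per-example loss \(\ell\), downstream loss \(J\)) is an illustrative sketch and is not claimed to match the paper's exact definitions.

\[
\min_{\gamma}\; J(\theta_T)
\quad \text{s.t.} \quad
\theta_{t+1} = \theta_t - \eta \sum_{n} \gamma_n \, \nabla \ell(x_n, \theta_t),
\qquad \gamma_n \ge 0, \quad \sum_{n} \gamma_n = 1 .
\]

Here the weights \(\gamma\) play the role of the control and the parameters \(\theta_t\) the state. Defining a costate sequence \(\lambda_t\) and the Hamiltonian \(H_t(\theta, \gamma, \lambda) = \lambda^{\top}\!\left(\theta - \eta \sum_n \gamma_n \nabla \ell(x_n, \theta)\right)\), the discrete-time maximum principle yields necessary conditions of the form

\[
\lambda_T = \nabla J(\theta_T),
\qquad
\lambda_t = \frac{\partial H_t}{\partial \theta}\Big|_{(\theta_t,\, \gamma^{*},\, \lambda_{t+1})},
\qquad
\gamma^{*} \ \text{optimizes}\ H_t(\theta_t, \cdot, \lambda_{t+1}) \ \text{over the simplex},
\]

i.e., the costate propagates the downstream objective backward through training, and the optimal selection weights at each step favor examples whose gradients best reduce that backward-propagated objective. PDS is built on conditions of this kind, approximating the optimal selection rather than solving the full coupled system exactly.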