Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs
Anshumann, Mohd Abbas Zaidi, Akhil Kedia, Jinwoo Ahn, Taehwak Kwon, Kangwook Lee, Haejun Lee, Joohyung Lee
2025-03-27
Summary
This paper is about training smaller AI language models faster and more cheaply by having them learn from a small, randomly chosen sample of a larger model's predictions instead of from all of them.
What's the problem?
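Teaching a smaller AI model what a larger model already knows, called knowledge distillation, is expensive during pre-training. The larger model's predictions must either be recomputed during training or stored for every piece of training text. Simple shortcuts, such as storing only the top few predictions, give the smaller model a distorted picture.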
What's the solution?
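The researchers created a new method called 'Random Sampling Knowledge Distillation'. Instead of saving only the larger model's few most likely predictions, which the paper proves gives a biased picture, it saves a small random sample of predictions, chosen in proportion to how likely the larger model thinks they are. On average this sample matches the larger model's full predictions, so the smaller model still learns the right thing while far less data has to be stored.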
Why does it matter?
It could make AI more accessible: in the paper's experiments, smaller models from 300 million to 3 billion parameters train with less than 10% overhead compared to standard training, while performing competitively with full distillation.
Abstract
Knowledge distillation can be a cost-effective technique to distill knowledge in Large Language Models, if the teacher output logits can be pre-computed and cached. However, successfully applying this to pre-training remains largely unexplored. In this work, we prove that naive approaches for sparse knowledge distillation such as caching Top-K probabilities, while intuitive, provide biased estimates of teacher probability distribution to the student, resulting in suboptimal performance and calibration. We propose an importance-sampling-based method 'Random Sampling Knowledge Distillation', which provides unbiased estimates, preserves the gradient in expectation, and requires storing significantly sparser logits. Our method enables faster training of student models with marginal overhead (<10%) compared to cross-entropy based training, while maintaining competitive performance compared to full distillation, across a range of model sizes from 300M to 3B.
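The abstract's central claim, that a renormalized Top-K cache is biased while random sampling is unbiased and preserves the gradient in expectation, can be checked numerically on a toy example. This is a minimal sketch, not the authors' implementation. It assumes the simplest proposal: token ids are sampled directly from the teacher distribution, so every importance weight is 1 and the stored target is just the empirical frequency of each sampled id. The helper names (`top_k_target`, `sampled_target`, `kd_grad`), the 8-token vocabulary, and the sample counts are hypothetical choices for illustration.

```python
import math
import random
from collections import Counter


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def top_k_target(p, k):
    """Naive sparse cache: keep the k most probable tokens and renormalize (biased)."""
    idx = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:k]
    mass = sum(p[i] for i in idx)
    return {i: p[i] / mass for i in idx}


def sampled_target(p, n_samples, rng):
    """Hypothetical sampled cache: draw token ids from p itself (proposal q = p,
    so every importance weight p/q equals 1) and store empirical frequencies.
    At most n_samples distinct token ids are stored."""
    draws = rng.choices(range(len(p)), weights=p, k=n_samples)
    return {i: c / n_samples for i, c in Counter(draws).items()}


def kd_grad(target, student_logits):
    """Gradient of the cross-entropy -sum_i t_i log softmax(z)_i w.r.t. z.
    For a target summing to 1 this is softmax(z) - t, which is linear in t."""
    q = softmax(student_logits)
    return [q[j] - target.get(j, 0.0) for j in range(len(q))]


# Toy teacher over an 8-token vocabulary and an untrained (uniform) student.
teacher = softmax([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5])
student_logits = [0.0] * len(teacher)

full = kd_grad(dict(enumerate(teacher)), student_logits)  # full distillation
topk = kd_grad(top_k_target(teacher, 2), student_logits)  # 2 cached ids

# Average the sampled-target gradient over many independent draws of 2 ids.
rng = random.Random(0)
trials = 20000
avg = [0.0] * len(teacher)
for _ in range(trials):
    g = kd_grad(sampled_target(teacher, 2, rng), student_logits)
    avg = [a + gj / trials for a, gj in zip(avg, g)]

err_sampled = max(abs(a - f) for a, f in zip(avg, full))
err_topk = max(abs(t - f) for t, f in zip(topk, full))
print(f"max |mean sampled grad - full grad| = {err_sampled:.4f}")
print(f"max |top-k grad - full grad|        = {err_topk:.4f}")
```

With the same budget of at most two stored token ids per position, the averaged sampled gradient converges to the full distillation gradient `softmax(z) - p`, because the expected sampled target equals `p` and the gradient is linear in the target. The Top-K gradient stays biased no matter how many steps are averaged, because the discarded tail mass is redistributed onto the top tokens. The paper may use a different proposal distribution or weighting than this toy version.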