Critical Tokens Matter: Token-Level Contrastive Estimation Enhances LLM's Reasoning Capability
Zicheng Lin, Tian Liang, Jiahao Xu, Xing Wang, Ruilin Luo, Chufan Shi, Siheng Li, Yujiu Yang, Zhaopeng Tu
2024-12-04

Summary
This paper introduces a new method called cDPO that improves the reasoning abilities of large language models (LLMs) by identifying the specific "critical tokens" in a response that steer the model toward wrong answers and by weighting them explicitly during preference training.
What's the problem?
Large language models are good at reasoning tasks, but they sometimes fail because of certain "critical tokens": specific tokens in a reasoning trajectory that push the chain of thought toward an incorrect conclusion. The authors observe that when a model is forced to decode alternative tokens in place of these critical ones, it often reaches the correct answer, which suggests that individual tokens, not just whole responses, can decide whether the reasoning succeeds. Standard alignment methods that score entire responses cannot pinpoint these tokens.
What's the solution?
The researchers developed a contrastive token-level estimation method, cDPO, that automatically identifies critical tokens and assigns them token-level rewards during alignment. They fine-tune two models: a positive model on correct reasoning trajectories and a negative model on incorrect ones. Comparing the likelihood each model assigns to every token of an incorrect trajectory reveals which tokens are critical, since the negative model favors them much more strongly than the positive model does (a sketch of this scoring step is given below). cDPO then extends standard DPO to the token level, using this likelihood difference as an importance weight so that training concentrates on the tokens most responsible for errors, which improves the overall accuracy of the model's reasoning.
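To make the contrastive estimation step concrete, here is a minimal sketch (not the authors' released code) of how two separately fine-tuned models could be used to score the tokens of an incorrect trajectory. The model paths, the example trajectory, and the top-k cutoff are placeholders; the paper's exact scoring rule may differ.

```python
# Hedged sketch: score each token of an incorrect reasoning trajectory by how much
# more likely it is under a "negative" model (fine-tuned on incorrect trajectories)
# than under a "positive" model (fine-tuned on correct trajectories). Tokens with
# the largest gap are candidate "critical tokens".
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def token_logprobs(model, input_ids):
    """Per-token log-probabilities of input_ids under a causal LM."""
    with torch.no_grad():
        logits = model(input_ids).logits                    # (1, T, vocab)
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)    # position t predicts token t+1
    return logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # (1, T-1)

tokenizer = AutoTokenizer.from_pretrained("path/to/positive-model")          # placeholder
pos_model = AutoModelForCausalLM.from_pretrained("path/to/positive-model").eval()
neg_model = AutoModelForCausalLM.from_pretrained("path/to/negative-model").eval()

trajectory = "Question: ... Step 1: ... Step 2: ..."  # an incorrect reasoning trajectory
ids = tokenizer(trajectory, return_tensors="pt").input_ids

# Contrastive score: how much more likely each token is under the negative model.
score = token_logprobs(neg_model, ids) - token_logprobs(pos_model, ids)

# Report the highest-scoring tokens as candidate critical tokens.
topk = torch.topk(score.squeeze(0), k=5)
critical = [tokenizer.decode(ids[0, i + 1]) for i in topk.indices]
print(list(zip(critical, topk.values.tolist())))
```

Tokens flagged this way are the natural candidates to receive stronger token-level penalties in the subsequent DPO-style alignment step.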
Why it matters?
This research is important because it enhances how well AI systems can reason and understand complex information. By focusing on critical tokens, cDPO helps LLMs make better decisions and provide more accurate answers, which is crucial for applications like education, customer service, and any field that relies on precise communication.
Abstract
Large Language Models (LLMs) have exhibited remarkable performance on reasoning tasks. They utilize autoregressive token generation to construct reasoning trajectories, enabling the development of a coherent chain of thought. In this work, we explore the impact of individual tokens on the final outcomes of reasoning tasks. We identify the existence of "critical tokens" that lead to incorrect reasoning trajectories in LLMs. Specifically, we find that LLMs tend to produce positive outcomes when forced to decode other tokens instead of critical tokens. Motivated by this observation, we propose a novel approach - cDPO - designed to automatically recognize and conduct token-level rewards for the critical tokens during the alignment process. Specifically, we develop a contrastive estimation approach to automatically identify critical tokens by comparing the generation likelihood of positive and negative models. To achieve this, we separately fine-tune the positive and negative models on various reasoning trajectories; consequently, they are capable of identifying critical tokens within incorrect trajectories that contribute to erroneous outcomes. Moreover, to further align the model with the critical token information during the alignment process, we extend the conventional DPO algorithm to token-level DPO and utilize the differential likelihood from the aforementioned positive and negative models as importance weights for token-level DPO learning. Experimental results on GSM8K and MATH500 benchmarks with two widely used models, Llama-3 (8B and 70B) and deepseek-math (7B), demonstrate the effectiveness of the proposed approach cDPO.
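As a rough illustration of the token-level objective the abstract describes (not necessarily the paper's exact formulation), one can decompose the sequence-level log-likelihood ratio of standard DPO into per-token ratios and weight the tokens of the dispreferred (incorrect) trajectory by a contrastive score from the positive and negative models. The notation pi_p, pi_n, and w_t below is an assumption for illustration.

```latex
% Hedged sketch of a token-level DPO objective with contrastive importance weights.
% pi_theta: policy being aligned; pi_ref: reference policy;
% y^w / y^l: preferred (correct) / dispreferred (incorrect) trajectory;
% pi_p / pi_n: the separately fine-tuned positive / negative models (assumed notation).
\[
w_t = \sigma\!\Big(\log \pi_n\big(y^{l}_{t}\mid x, y^{l}_{<t}\big)
      - \log \pi_p\big(y^{l}_{t}\mid x, y^{l}_{<t}\big)\Big)
\]
\[
\mathcal{L}_{\mathrm{cDPO}} = -\,\mathbb{E}_{(x,\,y^{w},\,y^{l})}\,
\log \sigma\!\Bigg(
\beta \sum_{t} \log \frac{\pi_\theta\big(y^{w}_{t}\mid x, y^{w}_{<t}\big)}
                         {\pi_{\mathrm{ref}}\big(y^{w}_{t}\mid x, y^{w}_{<t}\big)}
- \beta \sum_{t} w_t \log \frac{\pi_\theta\big(y^{l}_{t}\mid x, y^{l}_{<t}\big)}
                              {\pi_{\mathrm{ref}}\big(y^{l}_{t}\mid x, y^{l}_{<t}\big)}
\Bigg)
\]
```

In this sketch, tokens of the incorrect trajectory that the negative model strongly prefers receive weights near 1 and are penalized most, while uninformative tokens contribute little to the loss.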