Offline Reinforcement Learning for LLM Multi-Step Reasoning

Huaijie Wang, Shibo Hao, Hanze Dong, Shenao Zhang, Yilin Bao, Ziran Yang, Yi Wu

2024-12-23

Summary

This paper introduces OREO (Offline Reasoning Optimization), a new method that uses offline reinforcement learning to improve the ability of large language models (LLMs) to reason through complex problems step by step.

What's the problem?

Many existing methods for training LLMs, like Direct Preference Optimization (DPO), struggle with multi-step reasoning tasks because they rely on paired preference data that isn't readily available for these problems. Additionally, these methods treat every token in the reasoning process equally, which makes it hard to figure out which individual steps deserve credit when the only reward is whether the final answer turns out correct.

What's the solution?

OREO tackles these issues by using offline reinforcement learning to enhance multi-step reasoning. It learns a policy model and a value function together by optimizing the soft Bellman equation, which lets it better identify which steps in the reasoning process lead to correct answers. This reduces the need to collect paired data and improves how the model assigns credit to its decisions. The authors tested OREO on mathematical reasoning benchmarks (GSM8K, MATH) and embodied agent control (ALFWorld), where it outperformed existing offline learning methods.
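
To make the joint training idea more concrete, here is a minimal, hypothetical PyTorch sketch of a soft-Bellman-style residual loss computed over one offline reasoning trajectory. The function name, tensor shapes, reward convention, and the way each loss treats the other network as fixed are simplifying assumptions for illustration, not the paper's actual implementation.

```python
# Minimal sketch (not the paper's code): jointly fitting a policy and a value
# function with a soft-Bellman-style residual on one offline reasoning trace.
# Shapes, names, and the reward convention are assumptions for illustration.
import torch

beta = 0.1  # assumed KL-regularization strength


def oreo_style_losses(policy_logps, ref_logps, values, rewards):
    """policy_logps: (T,) log-prob of each step under the trained policy
    ref_logps:    (T,) log-prob of each step under a frozen reference model
    values:       (T+1,) V(s_0), ..., V(s_T) from the value head
    rewards:      (T,) per-step rewards, typically sparse (final answer only)
    """
    log_ratio = policy_logps - ref_logps
    td = rewards + values[1:] - values[:-1]  # one-step Bellman difference
    # Soft Bellman consistency in a KL-regularized MDP (deterministic steps):
    #   V(s_t) + beta * log(pi / pi_ref) ~= r_t + V(s_{t+1})
    value_loss = (td - beta * log_ratio.detach()).pow(2).mean()   # updates V only
    policy_loss = (td.detach() - beta * log_ratio).pow(2).mean()  # updates pi only
    return value_loss, policy_loss


# Toy usage: random tensors standing in for a 4-step trajectory whose final
# answer was judged correct (reward 1 at the last step, 0 elsewhere).
T = 4
policy_logps = torch.randn(T, requires_grad=True)
ref_logps = torch.randn(T)
values = torch.randn(T + 1, requires_grad=True)
rewards = torch.zeros(T)
rewards[-1] = 1.0
value_loss, policy_loss = oreo_style_losses(policy_logps, ref_logps, values, rewards)
(value_loss + policy_loss).backward()
```

Because the residual is computed per step rather than once per sequence, errors can be attributed to individual reasoning steps, which is the credit-assignment advantage the summary describes over sequence-level losses like DPO.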

Why it matters?

This research is important because it helps improve how AI models understand and solve complex problems that require multiple steps of reasoning. By enhancing the capabilities of LLMs in this area, OREO can lead to better performance in applications like mathematics, logic puzzles, and other tasks where step-by-step reasoning is crucial.

Abstract

Improving the multi-step reasoning ability of large language models (LLMs) with offline reinforcement learning (RL) is essential for quickly adapting them to complex tasks. While Direct Preference Optimization (DPO) has shown promise in aligning LLMs with human preferences, it is less suitable for multi-step reasoning tasks because (1) DPO relies on paired preference data, which is not readily available for multi-step reasoning tasks, and (2) it treats all tokens uniformly, making it ineffective for credit assignment in multi-step reasoning tasks, which often come with sparse reward. In this work, we propose OREO (Offline Reasoning Optimization), an offline RL method for enhancing LLM multi-step reasoning. Building on insights from previous works of maximum entropy reinforcement learning, it jointly learns a policy model and value function by optimizing the soft Bellman Equation. We show in principle that it reduces the need to collect pairwise data and enables better credit assignment. Empirically, OREO surpasses existing offline learning methods on multi-step reasoning benchmarks, including mathematical reasoning tasks (GSM8K, MATH) and embodied agent control (ALFWorld). The approach can be extended to a multi-iteration framework when additional resources are available. Furthermore, the learned value function can be leveraged to guide the tree search for free, which can further boost performance during test time.
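
The abstract's last point, using the learned value function to guide search at test time, can be pictured with a small hypothetical sketch. The procedure below is a simple step-level beam search, not necessarily the tree search used in the paper, and `sample_next_steps` and `value_fn` are placeholder callables rather than real APIs.

```python
# Hypothetical sketch: using a learned value function to guide step-level
# search at test time. This is a plain beam search for illustration; the
# paper's actual tree search may differ.
from typing import Callable, List, Tuple


def value_guided_search(
    question: str,
    sample_next_steps: Callable[[str, int], List[str]],  # policy proposes candidate steps
    value_fn: Callable[[str], float],                     # learned value scores a partial solution
    beam_width: int = 4,
    expansions: int = 8,
    max_steps: int = 10,
) -> str:
    # Each beam entry is the question plus the reasoning steps generated so far.
    beams: List[str] = [question]
    for _ in range(max_steps):
        candidates: List[Tuple[float, str]] = []
        for prefix in beams:
            for step in sample_next_steps(prefix, expansions):
                state = prefix + "\n" + step
                candidates.append((value_fn(state), state))
        # Keep only the partial solutions the value function rates most promising.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = [state for _, state in candidates[:beam_width]]
    return beams[0]
```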