Improving Transformer World Models for Data-Efficient RL
Antoine Dedieu, Joseph Ortiz, Xinghua Lou, Carter Wendelken, Wolfgang Lehrach, J Swaroop Guntupalli, Miguel Lazaro-Gredilla, Kevin Patrick Murphy
2025-02-04
Summary
This paper presents a method for improving how AI systems learn to make decisions in complex environments, using an approach called model-based reinforcement learning (MBRL). It introduces new techniques that help an AI agent perform better in Craftax-classic, a challenging open-world survival game where the goal is to make smart decisions over time.
What's the problem?
AI systems that learn by trial and error often require enormous amounts of data and time to become good at solving problems. In complex environments, like open-world games, they struggle to plan ahead, explore efficiently, and adapt to new challenges. Current methods are not sample-efficient enough and do not perform as well as humans on these tasks.
What's the solution?
The researchers developed a new MBRL algorithm built from a series of careful design choices. They started with a strong model-free baseline whose policy combines two types of neural networks (CNNs and RNNs). Then they added three key improvements: "Dyna with warmup," which trains the policy on both real environment data and data imagined by the world model, but only starts using imagined data after a warmup period; a "nearest neighbor tokenizer," which converts image patches into discrete tokens for the transformer world model more effectively; and "block teacher forcing," which lets the world model predict all the tokens of the next timestep jointly rather than one at a time. Together, these changes allowed their algorithm to outperform previous methods and, for the first time, to exceed human performance on Craftax-classic.
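To make the "nearest neighbor tokenizer" idea concrete, here is a minimal illustrative sketch (my own simplified version, not the authors' implementation): each image patch is compared against a growing codebook, reusing the nearest existing code when one is close enough and otherwise adding the patch as a new code. The class name, threshold parameter, and distance metric are assumptions for illustration.

```python
import numpy as np


class NearestNeighborTokenizer:
    """Illustrative nearest-neighbor patch tokenizer (simplified sketch,
    not the paper's implementation). Patches are matched against a growing
    codebook; a patch far from every existing code becomes a new code."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold  # assumed distance cutoff for reuse
        self.codebook = []          # list of flattened patch vectors

    def encode_patch(self, patch):
        """Return the token index for one patch, growing the codebook
        when no existing code is within the threshold."""
        v = patch.ravel().astype(np.float64)
        if self.codebook:
            dists = [np.linalg.norm(v - c) for c in self.codebook]
            i = int(np.argmin(dists))
            if dists[i] <= self.threshold:
                return i  # reuse the nearest existing code
        self.codebook.append(v)
        return len(self.codebook) - 1  # index of the newly added code

    def encode_image(self, image, patch_size):
        """Split an image into non-overlapping patches and tokenize each."""
        h, w = image.shape[:2]
        tokens = []
        for r in range(0, h, patch_size):
            for c in range(0, w, patch_size):
                tokens.append(
                    self.encode_patch(image[r:r + patch_size, c:c + patch_size])
                )
        return tokens
```

The appeal of this scheme is that the codebook needs no gradient-based training: token identities are stable from the start, which makes the tokens easier for the transformer world model to learn from.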
Why it matters?
This research is important because it shows how AI can become more efficient and smarter at solving complex problems. By requiring less data and achieving better results, this method could lead to advancements in areas like robotics, gaming, and real-world decision-making tasks. It also demonstrates that AI can now outperform humans in certain challenging environments, which is a big step forward for the field.
Abstract
We present an approach to model-based RL that achieves a new state of the art performance on the challenging Craftax-classic benchmark, an open-world 2D survival game that requires agents to exhibit a wide range of general abilities -- such as strong generalization, deep exploration, and long-term reasoning. With a series of careful design choices aimed at improving sample efficiency, our MBRL algorithm achieves a reward of 67.4% after only 1M environment steps, significantly outperforming DreamerV3, which achieves 53.2%, and, for the first time, exceeds human performance of 65.0%. Our method starts by constructing a SOTA model-free baseline, using a novel policy architecture that combines CNNs and RNNs. We then add three improvements to the standard MBRL setup: (a) "Dyna with warmup", which trains the policy on real and imaginary data, (b) "nearest neighbor tokenizer" on image patches, which improves the scheme to create the transformer world model (TWM) inputs, and (c) "block teacher forcing", which allows the TWM to reason jointly about the future tokens of the next timestep.
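As a rough illustration of the "block teacher forcing" idea from the abstract, one common way to let a transformer reason jointly about all tokens of a timestep is a block-causal attention mask: tokens within the same timestep's block may attend to each other, while attention to later blocks is forbidden. The sketch below (my own construction, not the authors' code) builds such a mask.

```python
import numpy as np


def block_causal_mask(num_blocks, block_size):
    """Build a block-causal attention mask as a boolean matrix
    (True = attention allowed). Position i may attend to position j
    iff j's block index is <= i's block index, so all tokens inside
    one timestep's block are visible to each other, but no block can
    see into the future."""
    n = num_blocks * block_size
    blocks = np.arange(n) // block_size  # block index of each position
    return blocks[:, None] >= blocks[None, :]
```

For example, with 2 timesteps of 2 tokens each, the first two positions can attend to each other but not to positions 2 and 3, while the last two positions can attend to everything. A strictly token-causal mask would instead force each token to wait for the previous token of the same timestep, which is what block teacher forcing relaxes.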