The Curse of Conditions: Analyzing and Improving Optimal Transport for Conditional Flow-Based Generation
Ho Kei Cheng, Alexander Schwing
2025-03-14
Summary
This paper addresses a failure of minibatch 'optimal transport' in flow-based AI image generation that appears when images are generated from a condition (like generating a cat picture given the condition 'fluffy').
What's the problem?
Optimal transport pairs random noise samples with training images so that the paths between them are straighter, which lets the AI generate images in fewer steps. When conditions are added, however, the pairing ignores them. As a result, each condition sees only a skewed slice of the noise distribution during training, while at test time noise is drawn from the full, unbiased distribution. This mismatch between training and testing lowers image quality.
What's the solution?
The researchers created a fix called C^2OT (conditional optimal transport) that adds a 'conditional weighting term' to the cost matrix used to compute the optimal transport assignment. This makes the pairing take the conditions into account during training, which narrows the gap between what the AI sees in training and what it sees at test time.
Why does it matter?
This work matters because it lets conditional image generation keep the efficiency benefits of optimal transport. The fix works with both discrete and continuous conditions and performs better overall than existing baselines across different function evaluation budgets.
Abstract
Minibatch optimal transport coupling straightens paths in unconditional flow matching. This leads to computationally less demanding inference as fewer integration steps and less complex numerical solvers can be employed when numerically solving an ordinary differential equation at test time. However, in the conditional setting, minibatch optimal transport falls short. This is because the default optimal transport mapping disregards conditions, resulting in a conditionally skewed prior distribution during training. In contrast, at test time, we have no access to the skewed prior, and instead sample from the full, unbiased prior distribution. This gap between training and testing leads to subpar performance. To bridge this gap, we propose conditional optimal transport (C^2OT), which adds a conditional weighting term in the cost matrix when computing the optimal transport assignment. Experiments demonstrate that this simple fix works with both discrete and continuous conditions in 8gaussians-to-moons, CIFAR-10, ImageNet-32x32, and ImageNet-256x256. Our method performs better overall compared to the existing baselines across different function evaluation budgets. Code is available at https://hkchengrex.github.io/C2OT
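To make the idea of a conditional weighting term in the cost matrix concrete, here is a minimal sketch of a minibatch pairing step. It is not the authors' implementation. The function name `conditional_ot_pairing`, the `weight` parameter, the use of squared Euclidean distance for both terms, and the assumption that each prior sample carries its own condition vector `c0` are all illustrative choices that the abstract does not specify.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def conditional_ot_pairing(x0, x1, c0, c1, weight=1.0):
    """Pair prior samples x0 with data samples x1 within one minibatch.

    Hypothetical helper for illustration only.
    x0, x1: arrays of shape (batch, dim) with prior and data samples.
    c0, c1: arrays of shape (batch, cond_dim) with the conditions
            attached to each prior and data sample (an assumption).
    Returns `perm` such that x0[i] is paired with x1[perm[i]].
    """
    # Standard OT cost: squared Euclidean distance between every pair.
    sample_diff = x0[:, None, :] - x1[None, :, :]
    transport_cost = np.sum(sample_diff ** 2, axis=-1)

    # Assumed conditional term: penalize pairs whose conditions differ.
    cond_diff = c0[:, None, :] - c1[None, :, :]
    condition_cost = np.sum(cond_diff ** 2, axis=-1)

    # The weighted sum is the cost matrix given to the assignment solver.
    cost = transport_cost + weight * condition_cost
    rows, cols = linear_sum_assignment(cost)

    perm = np.empty(len(x0), dtype=int)
    perm[rows] = cols
    return perm
```

In this sketch, setting `weight=0` removes the conditional term and the function reduces to a plain minibatch optimal transport pairing that disregards conditions. A larger `weight` increasingly discourages pairing samples whose conditions differ.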