Cut Your Losses in Large-Vocabulary Language Models

Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl

2024-11-15

Summary

This paper introduces a new method called Cut Cross-Entropy (CCE) that reduces the memory large language models (LLMs) use during training without sacrificing training speed or convergence.

What's the problem?

As language models grow larger, their vocabularies grow too, and this makes training very memory-intensive. A disproportionate share of that memory goes to a single step: computing the cross-entropy loss, which builds a matrix of logits with one entry for every pair of input token and vocabulary item. For small models, this matrix alone can consume an order of magnitude more memory than the rest of the model combined, which limits what can be trained on a given amount of hardware.
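A back-of-the-envelope calculation shows why the logit matrix dominates. The helpers below are hypothetical names for illustration; they assume the logits are stored densely as 4-byte floats, and the token count and vocabulary size are illustrative values rather than the paper's exact configuration, so the result is not meant to reproduce the paper's 24 GB figure for Gemma 2 (2B).

```python
def logit_matrix_bytes(num_tokens: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one dense [num_tokens x vocab_size] logit matrix."""
    return num_tokens * vocab_size * bytes_per_elem


def per_token_bytes(num_tokens: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one scalar per token, e.g. a correct-token logit or a log-sum-exp."""
    return num_tokens * bytes_per_elem


# Illustrative assumptions: 8,192 tokens per batch, 256,000 vocabulary entries, float32.
tokens, vocab = 8_192, 256_000
print(f"dense logits:    {logit_matrix_bytes(tokens, vocab) / 1e9:.1f} GB")
print(f"one value/token: {per_token_bytes(tokens) / 1e3:.1f} KB")
```

The dense matrix grows with the product of token count and vocabulary size, while per-token quantities grow only with the token count; any extra dense copies, such as the gradient of the logits, add further multiples of the dense cost.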

What's the solution?

The authors propose CCE, which changes how the cross-entropy loss is computed. Instead of computing the logits (the raw scores for every word in the vocabulary) for all tokens and storing them in GPU global memory, CCE computes only the logit for the correct token and evaluates the log-sum-exp over all logits on the fly, inside a custom kernel. For the Gemma 2 (2B) model, this reduces the memory used by the loss computation from 24 GB to just 1 MB. To keep training fast, CCE also exploits the sparsity of the softmax and skips parts of the gradient computation whose contribution falls below numerical precision.
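The core arithmetic can be illustrated in a few lines. This pure-Python sketch is an illustration, not the authors' GPU kernel: the names `naive_cross_entropy`, `chunked_cross_entropy`, and `chunk_size` are hypothetical, and it loops over one token at a time, whereas a real kernel would process tiles of tokens and vocabulary entries in on-chip memory. The chunked version computes only the correct token's logit plus a running (online) log-sum-exp, so at most `chunk_size` logits per token exist at once; the naive version builds the full matrix for comparison.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def naive_cross_entropy(embeddings, classifier, targets):
    """Reference: builds the full [N x V] logit matrix, then averages -log softmax."""
    logits = [[dot(e, c) for c in classifier] for e in embeddings]  # all N x V logits
    losses = []
    for row, t in zip(logits, targets):
        m = max(row)
        lse = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(lse - row[t])
    return sum(losses) / len(losses)


def chunked_cross_entropy(embeddings, classifier, targets, chunk_size=2):
    """Same loss, but only one chunk of logits per token exists at a time."""
    losses = []
    for e, t in zip(embeddings, targets):
        correct_logit = dot(e, classifier[t])  # only the target's logit is kept
        running_max = -math.inf
        running_sum = 0.0  # sum of exp(z - running_max) over chunks seen so far
        for start in range(0, len(classifier), chunk_size):
            chunk = [dot(e, c) for c in classifier[start:start + chunk_size]]
            new_max = max(running_max, max(chunk))
            # Online log-sum-exp: rescale the old sum to the new max, then add this chunk.
            running_sum = running_sum * math.exp(running_max - new_max) + sum(
                math.exp(z - new_max) for z in chunk
            )
            running_max = new_max
        lse = running_max + math.log(running_sum)
        losses.append(lse - correct_logit)
    return sum(losses) / len(losses)


E = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]  # 3 tokens, hidden size 2
C = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.2]]  # 5-word vocabulary
y = [2, 0, 3]  # correct token ids
print(naive_cross_entropy(E, C, y), chunked_cross_entropy(E, C, y))
```

Both functions should return the same loss up to floating-point rounding. Rescaling the running sum whenever the maximum changes keeps the exponentials from overflowing, which is what makes it safe to stream over the vocabulary instead of storing every logit.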

Why does it matter?

This research matters because it lets researchers and developers train large-vocabulary language models with far less memory. Freeing that memory makes larger or more complex models feasible on the same hardware, which can translate into better performance on natural language processing tasks.

Abstract

As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.
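The gradient-skipping idea can be illustrated for a single token. This is a pure-Python sketch under stated assumptions, not the authors' kernel: `dot` and `embedding_grad` are hypothetical helpers, the filtering is done per vocabulary chunk of one token rather than per GPU block, and the cut-off of `2**-12` is an assumed value chosen only for illustration. Because the gradient of the loss with respect to the logits is p - y (softmax probabilities minus the one-hot target), a chunk whose probabilities are all tiny contributes almost nothing to the gradient and can be skipped.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def embedding_grad(e, classifier, target, threshold=0.0, chunk_size=2):
    """Gradient of the loss -log softmax(C e)[target] with respect to e.

    The exact gradient is sum_v (p_v - y_v) * c_v, where p = softmax(C e) and y is
    the one-hot target. Vocabulary chunks whose probabilities are all below
    `threshold` (an assumed cut-off) and that do not hold the target are skipped.
    Returns (gradient, number_of_skipped_chunks).
    """
    logits = [dot(e, c) for c in classifier]  # one token only, kept simple here
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    grad = [0.0] * len(e)
    skipped = 0
    for start in range(0, len(classifier), chunk_size):
        idx = range(start, min(start + chunk_size, len(classifier)))
        probs = [math.exp(logits[v] - lse) for v in idx]
        # The target's chunk always carries the -1 term, so it is never skipped.
        if target not in idx and max(probs) < threshold:
            skipped += 1
            continue
        for v, p in zip(idx, probs):
            coeff = p - (1.0 if v == target else 0.0)
            for d in range(len(e)):
                grad[d] += coeff * classifier[v][d]
    return grad, skipped


e = [4.0, 0.0]
C = [[3.0, 0.0], [-2.0, 0.0], [-2.5, 1.0], [-3.0, 0.0], [0.0, 0.5], [-1.0, -1.0]]
full, _ = embedding_grad(e, C, target=4)  # threshold 0.0: nothing is skipped
filtered, n = embedding_grad(e, C, target=4, threshold=2**-12)
print(n, max(abs(a - b) for a, b in zip(full, filtered)))
```

In this toy example one of the three chunks is skipped and the gradient changes by a negligible amount (well under 1e-8). Keeping the target's chunk unconditionally matters: its -1 term is never negligible, even when the target's probability is tiny.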