Scalify: scale propagation for efficient low-precision LLM training

Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon

2024-07-25

Summary

This paper presents Scalify, a new method designed to make training large language models (LLMs) more efficient by using low-precision formats like float8. It aims to simplify the process of achieving high accuracy while reducing the computational resources needed.

What's the problem?

Training large language models typically requires a lot of computing power and memory, especially when using high-precision formats. Low-precision formats such as float8 can reduce these demands, but the techniques needed to match the accuracy of higher-precision training are often complex and sometimes brittle. This has slowed the adoption of low-precision training in the ML community.
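
To make the brittleness concrete, here is a minimal NumPy sketch. It is an illustration, not code from the paper or from jax-scalify. It assumes hypothetical gradient values around 1e-9, which lie below float16's smallest subnormal (about 6e-8). A naive cast flushes them to zero, while storing the tensor as float16 data plus a separate float32 power-of-two scale keeps them. Both the values and the scale-selection rule are assumptions chosen for the example.

```python
import numpy as np

# Hypothetical gradient values, all below float16's smallest subnormal (~6e-8).
grads = np.array([3e-9, -1.5e-9, 7e-10], dtype=np.float32)

# Naive cast: each value rounds to (signed) zero in float16.
naive = grads.astype(np.float16)
print(naive)  # every entry is zero

# Per-tensor scaling (assumed rule): a power-of-two float32 scale so |data| <= 1.
scale = np.float32(2.0 ** np.ceil(np.log2(np.max(np.abs(grads)))))
data = (grads / scale).astype(np.float16)  # low-precision payload

# Recover the represented values in float32: value = data * scale.
recovered = data.astype(np.float32) * scale
rel_err = np.max(np.abs(recovered - grads) / np.abs(grads))
print(rel_err)  # only float16 rounding error remains
```

Existing tensor scaling methods, which Scalify generalizes, build on this idea; the practical difficulty is choosing and maintaining a suitable scale for every tensor in the training graph.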

What's the solution?

Scalify introduces an end-to-end scale propagation approach for computational graphs that streamlines the use of low-precision formats in LLM training. It generalizes and formalizes existing tensor scaling methods, enabling float8 matrix multiplication and float8 gradient representation. Scalify also supports storing optimizer states in float16. In the researchers' experiments, Scalify supported all of these operations out of the box, which simplifies the training setup.
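
The core idea of scale propagation can be sketched as follows: each tensor is carried as a low-precision `data` array plus a higher-precision `scale`, and every operation computes its output scale from its input scales. The names `ScaledArray`, `as_scaled`, `scaled_matmul` and `scaled_add` are hypothetical and are not the jax-scalify API. float16 stands in for float8 because NumPy has no float8 dtype. The extra division by sqrt(k) in the matmul is an assumption borrowed from unit-scaling-style rules, not a detail confirmed by this summary.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ScaledArray:
    """Hypothetical scaled tensor representing the value data * scale."""
    data: np.ndarray   # low-precision payload (float16 stands in for float8)
    scale: np.float32  # higher-precision scalar scale

    def to_array(self) -> np.ndarray:
        # Materialize the represented value in float32.
        return self.data.astype(np.float32) * self.scale


def as_scaled(x: np.ndarray) -> ScaledArray:
    # Pick a power-of-two scale so that |data| <= 1, then cast data to float16.
    max_abs = float(np.max(np.abs(x)))
    exponent = np.ceil(np.log2(max_abs)) if max_abs > 0 else 0.0
    scale = np.float32(2.0 ** exponent)
    return ScaledArray((x / scale).astype(np.float16), scale)


def scaled_matmul(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Propagation rule: (A * s_a) @ (B * s_b) = (A @ B) * (s_a * s_b).
    # Assumption: also move a factor sqrt(k) from the data into the scale so the
    # output data stays near unit range (a unit-scaling-style choice).
    k = a.data.shape[-1]
    norm = np.float32(np.sqrt(k))
    # Accumulate in float32, as low-precision matmul hardware typically does.
    out = (a.data.astype(np.float32) @ b.data.astype(np.float32)) / norm
    return ScaledArray(out.astype(np.float16), a.scale * b.scale * norm)


def scaled_add(a: ScaledArray, b: ScaledArray) -> ScaledArray:
    # Propagation rule for addition: align both operands to the larger scale.
    scale = max(a.scale, b.scale)
    out = (a.data.astype(np.float32) * (a.scale / scale)
           + b.data.astype(np.float32) * (b.scale / scale))
    return ScaledArray(out.astype(np.float16), scale)


# Usage: tiny activations and weights whose product mostly underflows in plain float16.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64)).astype(np.float32) * 1e-6
w = rng.standard_normal((64, 8)).astype(np.float32) * 1e-3

y = scaled_matmul(as_scaled(x), as_scaled(w))
ref = x @ w
rel_err = np.max(np.abs(y.to_array() - ref)) / np.max(np.abs(ref))
print(y.data.dtype, y.scale, rel_err)
```

Each operation needs only its inputs' scales to compute its output scale, so scales can be threaded through a whole graph without inspecting tensor values at every step. The specific rules Scalify uses for each primitive, and its float8 handling, are defined in the paper and its JAX implementation rather than in this sketch.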

Why does it matter?

This research matters because it could make training large language models more efficient and accessible. By reducing the complexity of low-precision training, Scalify could let more researchers and developers use float8 and float16 formats on hardware that supports them, without deep expertise in tensor scaling techniques. This could make LLM training faster and cheaper, and broaden where these models can be developed.

Abstract

Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed down by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify