New AMD Paper Overturns Conventional Wisdom: FP4 Training Instability Is Not Caused by Insufficient Randomness
AMD's new research challenges the conventional understanding of FP4 training instability. Reducing precision from FP8 to FP4 promises doubled computational throughput and is supported by new hardware such as the NVIDIA Blackwell and AMD MI350 series. Yet training large language models natively in FP4 has been notoriously unstable, a problem often attributed to insufficient stochasticity.
The paper "Pretraining large language models with MXFP4 on Native FP4 Hardware" demonstrates successful end-to-end FP4 pre-training of Llama 3.1-8B on AMD MI355X GPUs using the MXFP4 format, achieving a 9-10% overall speedup over FP8. Crucially, it identifies the root cause of the instability: not a lack of randomness, but the accumulation of *structural micro-scaling errors* along the sensitive weight gradient (Wgrad) path.
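To make "micro-scaling" concrete, here is a minimal pure-Python sketch of MXFP4-style block quantization as described in the OCP Microscaling format: each block of 32 values shares one power-of-two scale, and each element is stored as FP4 (E2M1), whose representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6. The function name `quantize_block` and the simplified tie-breaking are assumptions for illustration, not the paper's implementation. The example block shows one way a shared scale can introduce structured error: a single outlier sets the scale, and the smaller values in the same block round to zero.

```python
import math

# Representable magnitudes of an FP4 E2M1 element.
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32   # MX block size: 32 elements share one scale
E2M1_EMAX = 2     # largest E2M1 exponent (6 = 1.5 * 2**2)


def quantize_block(block):
    """Quantize one block to an MXFP4-like grid and dequantize it again.

    Hypothetical helper for illustration. Ties between grid points are
    resolved toward the smaller magnitude, which is a simplification.
    Returns (dequantized_values, shared_scale).
    """
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return [0.0] * len(block), 1.0
    # Shared power-of-two scale chosen from the largest magnitude in the block.
    shared_exp = math.floor(math.log2(amax)) - E2M1_EMAX
    scale = 2.0 ** shared_exp
    out = []
    for v in block:
        # Scale into the element range and saturate at the largest FP4 value.
        mag = min(abs(v) / scale, FP4_GRID[-1])
        q = min(FP4_GRID, key=lambda g: abs(g - mag))
        out.append(math.copysign(q, v) * scale)
    return out, scale


# One outlier (100.0) among 31 small values (1.0).
block = [1.0] * (BLOCK_SIZE - 1) + [100.0]
dequantized, scale = quantize_block(block)
print(scale)            # shared scale chosen for the block
print(dequantized[0])   # what a small value becomes
print(dequantized[-1])  # what the outlier becomes
```

In this example the outlier forces a shared scale of 16, so every 1.0 in the block is flushed to 0 while the outlier saturates to 96. This is only an illustration of block-scaling error in general; the paper's specific error analysis of the Wgrad path is not reproduced here.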
Through controlled experiments, the researchers found that quantizing the Wgrad operation to FP4 caused significant convergence degradation. Counterintuitively, common stochasticity-based mitigation techniques such as stochastic rounding and randomized Hadamard transforms made performance worse. In contrast, applying a *deterministic* Hadamard transform stabilized training by keeping error patterns consistent, reducing the extra token cost from 26-27% to just 8-9%.
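The difference between the two Hadamard variants can be sketched in a few lines of pure Python. The code below implements a normalized fast Walsh-Hadamard transform (the deterministic case) and a randomized variant that flips signs with a random vector before transforming. The function names, the block length of 16 (chosen so the arithmetic is exact), and the use of `random.Random` are assumptions for illustration; how the paper applies the transform inside the Wgrad GEMM is not shown here.

```python
import math
import random


def hadamard(vec):
    """Normalized fast Walsh-Hadamard transform (length must be a power of two)."""
    n = len(vec)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    out = list(vec)
    h = 1
    while h < n:
        # Butterfly step: combine pairs that are h positions apart.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = out[j], out[j + h]
                out[j], out[j + h] = a + b, a - b
        h *= 2
    norm = 1.0 / math.sqrt(n)  # makes the transform orthogonal
    return [v * norm for v in out]


def randomized_hadamard(vec, rng):
    """Hypothetical randomized variant: random sign flips, then the transform."""
    signs = [rng.choice((-1.0, 1.0)) for _ in vec]
    return hadamard([s * v for s, v in zip(signs, vec)]), signs


# A block with a single outlier.
x = [8.0] + [0.0] * 15

y = hadamard(x)              # deterministic: same output on every call
x_back = hadamard(y)         # the normalized transform is its own inverse
z, signs = randomized_hadamard(x, random.Random(0))

print(y)       # outlier energy spread evenly across the block
print(x_back)  # original block recovered
print(z)       # same magnitudes, sign pattern depends on the random draw
```

Both variants spread the outlier across the block, which makes the values easier to fit under a shared scale. The deterministic transform produces the same rotation at every step, whereas the randomized one changes the sign pattern with each draw of `signs`.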
This work has three significant implications: 1) It provides a clear diagnostic for low-precision training instability, steering attention towards structural errors. 2) It moves FP4 from a primarily inference-focused format into the realm of viable training. 3) It builds on the open OCP Microscaling (MX) standard, promoting cross-vendor compatibility. The research marks a critical step towards more economical large-model training by extending the boundaries of low-precision computation.
Source: marsbit