Low-rank representation
Why Does BF16 FlashAttention "Blow Up" Training? Tsinghua Gives the First Mechanistic Explanation and Stabilizes Training with a Minimal Change
Synced (机器之心) · 2026-03-03 23:19
Core Insights
The article examines training instability under low precision, particularly BF16. It argues that FlashAttention failures are not random bugs. Instead, under specific conditions FlashAttention triggers systematic numerical biases that lead to loss explosions [1][4].

Group 1: Background and Importance
- Low-precision training has become standard practice in industry, and BF16/FP16 are widely used to improve training efficiency. However, training can become unstable as numerical precision approaches its limits [2][3].
- FlashAttention is a critical component for training long-context models. It has been linked to failures that are reproducible but unexplained and have been reported for years, with no clear mechanism connecting numerical errors to loss explosions [4].

Group 2: Research Methodology
- The authors reproduced the failure rigorously with GPT-2. They eliminated randomness by recording the exact sequence of data batches and replaying it in every run [6].
- Using the spectral norm and other metrics, they narrowed the problem down to specific layers and attention heads. They traced the instability to a particular intermediate quantity in FlashAttention's backward pass [7].

Group 3: Mechanism of Failure
- Similar low-rank structures (update directions that recur across training steps) can amplify numerical errors. This turns the errors into persistent biases rather than mere noise, causing abnormal growth in weight updates and ultimately loss explosions [8][9].
- A key observation concerns systematic rounding biases in BF16. These appear in particular when a row of attention scores contains several identical maximum values, which can trigger dangerous conditions in subsequent calculations [13][18].

Group 4: Proposed Solutions
- The authors propose a simple fix: adjust the safe softmax so that the largest exponentiated value in each row stays strictly below 1. This prevents the subsequent bias in BF16 accumulation from being triggered [22][25].
- In experiments, training with the modified FlashAttention remained stable, with no sudden loss explosions, across various hardware setups [26].

Group 5: Broader Implications
- Low-precision errors should not be treated as random noise. Under specific value distributions and discrete events, they can form systematic biases [31].
- Model structure can amplify these biases. This happens notably when attention produces similar low-rank update directions, which lets errors accumulate in the same direction [31].
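The summary says the authors localized the failure to specific layers and attention heads using the spectral norm, among other metrics. It does not describe their tooling. As an illustration only, the sketch below estimates a matrix's spectral norm (its largest singular value) by power iteration in plain Python. The function names, the toy "head" matrices, and the idea of comparing heads this way are assumptions, not the authors' code.

```python
import math
import random


def matvec(a, x):
    """Multiply matrix `a` (a list of rows) by vector `x`."""
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def transpose(a):
    """Return the transpose of matrix `a`."""
    return [list(col) for col in zip(*a)]


def spectral_norm(a, iters=100, seed=0):
    """Estimate the largest singular value of `a` by power iteration on A^T A."""
    rng = random.Random(seed)
    at = transpose(a)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(a[0]))]  # random start vector
    sigma = 0.0
    for _ in range(iters):
        w = matvec(at, matvec(a, v))  # w = A^T A v
        norm_w = math.sqrt(sum(x * x for x in w))
        if norm_w == 0.0:  # A^T A maps v to zero, e.g. an all-zero matrix
            return 0.0
        v = [x / norm_w for x in w]  # keep v at unit length
        sigma = math.sqrt(sum(x * x for x in matvec(a, v)))  # ||A v|| for unit v
    return sigma


# Hypothetical usage: compare two toy "head" matrices and report the larger norm.
heads = {"head_0": [[3.0, 0.0], [0.0, 1.0]], "head_1": [[1.0, 2.0], [2.0, 4.0]]}
norms = {name: spectral_norm(w) for name, w in heads.items()}
print(max(norms, key=norms.get))  # prints head_1 (norm 5 vs. 3)
```

Tracking such a value per head across training steps would flag a head whose norm grows abnormally. In practice, `numpy.linalg.norm(W, 2)` computes the same quantity for a 2-D array.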
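A central claim of the mechanism and broader-implications sections is that a small error with a consistent sign behaves very differently from random noise. A toy model makes the difference concrete. Zero-mean errors partly cancel, so their sum grows roughly like noise·√steps. A constant bias adds up linearly, as bias·steps. The `simulate` function and its parameter values are hypothetical and say nothing about the paper's actual error magnitudes.

```python
import random


def simulate(steps, bias, noise=1e-3, seed=0):
    """Sum per-step errors: zero-mean Gaussian noise plus a constant bias.

    Toy model only: every step adds its error to the same quantity, as when
    successive updates keep pointing in a similar direction.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(steps):
        total += rng.gauss(bias, noise)  # bias + noise * N(0, 1)
    return total


unbiased = simulate(100_000, bias=0.0)
biased = simulate(100_000, bias=1e-4)
print(f"zero-mean errors: {unbiased:+.3f}, biased errors: {biased:+.3f}")
```

The example uses `steps=100_000`, `noise=1e-3`, and `bias=1e-4`. With these values, the random part of the sum is on the order of 0.3, while the bias alone contributes about 10.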
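The summary does not spell out the exact BF16 arithmetic inside FlashAttention. What follows is therefore a generic emulation, not a reproduction of the kernel. It illustrates two ingredients the article names:
- When a score row has several identical maxima, safe softmax turns each of them into exactly 1.0.
- BF16 round-to-nearest-even can push an accumulated sum consistently in one direction.

`to_bf16` emulates bfloat16 by rounding a float32 bit pattern to its top 16 bits. It assumes finite inputs and first rounds the Python float to float32. All function names here are illustrative.

```python
import math
import struct


def to_bf16(x):
    """Round `x` to the nearest bfloat16 value (ties to even), returned as a float.

    Assumes a finite input that fits in float32: struct first rounds it to
    float32, then the low 16 bits of the float32 pattern are rounded off.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    lsb = (bits >> 16) & 1  # lowest bit that bfloat16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round half to even, clear low bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def safe_softmax_numerators(scores):
    """Safe-softmax numerators exp(s - max(s)); every tied maximum becomes exactly 1.0."""
    m = max(scores)
    return [math.exp(s - m) for s in scores]


def bf16_sum(values):
    """Add values one by one, rounding the running total to bfloat16 after each step."""
    total = 0.0
    for v in values:
        total = to_bf16(total + to_bf16(v))
    return total


row = [2.0, 2.0, 2.0, 0.5]  # three tied maxima
print(safe_softmax_numerators(row).count(1.0))  # prints 3
print(bf16_sum([1.0] * 1000))  # prints 256.0, not 1000.0
```

Adding 1.0 a thousand times with BF16 rounding stops at 256. The reason is that 257 lies exactly halfway between its representable neighbors 256 and 258, and ties round to the even mantissa (256). Every further addition is therefore rounded away in the same direction.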
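The fix described under Proposed Solutions is stated only at a high level: change safe softmax so that the largest exponentiated value in a row stays strictly below 1. One way to get that property, sketched below under stated assumptions, is to subtract slightly more than the row maximum whenever the maximum is tied. Subtracting the same constant from every score leaves softmax unchanged, so the normalized probabilities stay the same in exact arithmetic. The `shift_on_tie` parameter and the tie rule are hypothetical, and the authors' actual adjustment may differ.

```python
import math


def safe_softmax(scores, shift_on_tie=0.0):
    """Safe softmax that subtracts extra when the row maximum is tied.

    Returns (numerators, probabilities). `shift_on_tie` is a hypothetical knob:
    if the maximum appears more than once, the offset becomes max + shift_on_tie,
    so every numerator exp(s - offset) is strictly below 1 when shift_on_tie > 0.
    """
    m = max(scores)
    ties = sum(1 for s in scores if s == m)
    offset = m + shift_on_tie if ties > 1 else m
    numerators = [math.exp(s - offset) for s in scores]
    denom = sum(numerators)
    return numerators, [n / denom for n in numerators]


row = [2.0, 2.0, 0.0]
plain_nums, plain_probs = safe_softmax(row)
fixed_nums, fixed_probs = safe_softmax(row, shift_on_tie=0.5)
print(plain_nums.count(1.0), max(fixed_nums) < 1.0)  # prints 2 True
# Largest difference between the two probability vectors (floating-point sized).
print(max(abs(a - b) for a, b in zip(plain_probs, fixed_probs)))
```

Rows without a tied maximum take the ordinary safe-softmax path. Keeping `shift_on_tie` modest matters in low precision, since a large offset pushes all numerators toward the underflow range.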