Researchers from Stanford University have developed a new technique called FlashAttention-2 that significantly speeds up training of large Transformer models on longer sequences. In a paper posted to arXiv, the authors demonstrate attention computation up to 2x faster than the previous state of the art, FlashAttention.
Transformers have become the dominant architecture for natural language processing, but a key limitation is that their memory and compute requirements scale quadratically with sequence length. This has constrained most Transformer models, such as GPT-3, to sequences of 2,048 tokens or less. FlashAttention-2 aims to push this boundary by optimizing memory access patterns and parallelism in the self-attention computation, which is where the quadratic cost arises.
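To see where the quadratic cost comes from, consider a plain, unfused attention implementation: it materializes an N x N matrix of pairwise scores before the softmax. The sketch below is illustrative (shapes and names are our own), not code from the paper.

```python
import torch

def naive_attention(q, k, v):
    """Standard (unfused) attention that materializes the full N x N score matrix.

    q, k, v: (batch, heads, seq_len, head_dim)
    Memory for `scores` grows quadratically with seq_len, which is what limits
    context length in a straightforward implementation.
    """
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # (batch, heads, N, N)
    probs = torch.softmax(scores, dim=-1)
    return probs @ v                              # (batch, heads, N, head_dim)
```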
The main innovations in FlashAttention-2 are:
- Tweaking the algorithm to cut down on non-matrix-multiply operations, which GPUs execute far more slowly per FLOP than matrix multiplies. This raises the share of compute time spent in fast matmuls.
- Parallelizing the attention computation over the sequence-length dimension, in addition to batch size and number of heads, which improves GPU occupancy for long sequences (see the sketch after this list).
- Optimizing how work is partitioned between warps within each GPU thread block, reducing unnecessary reads and writes to shared memory.
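As a rough picture of what a FlashAttention-style kernel computes, the Python sketch below processes keys and values one block at a time with an online (running) softmax, so the full N x N score matrix is never formed. It is a high-level illustration under simplifying assumptions (single head, no causal mask, illustrative block size), not the paper's CUDA kernel; the actual speedups come from how these blocks are mapped onto GPU thread blocks and warps.

```python
import torch

def tiled_attention(q, k, v, block_size=128):
    """Illustrative tiled attention with an online (running) softmax.

    Keys and values are processed one block at a time; only a running max,
    a running normalizer, and a running weighted sum are kept per query,
    so the full N x N score matrix never exists in memory.
    q, k, v: (seq_len, head_dim) for a single head, for clarity.
    """
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))
    row_sum = torch.zeros(n, 1)

    for start in range(0, n, block_size):
        kb = k[start:start + block_size]            # (B, d) block of keys
        vb = v[start:start + block_size]            # (B, d) block of values
        scores = (q @ kb.T) * scale                 # (n, B): one block of columns

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)   # rescale earlier partial results
        p = torch.exp(scores - new_max)             # (n, B) unnormalized weights

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max

    return out / row_sum
```

The rescaling steps (the `correction` factors) are exactly the kind of non-matmul work the first bullet refers to; per the paper, FlashAttention-2 restructures this loop so that less of that rescaling is needed.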
In benchmarks on A100 GPUs, FlashAttention-2 achieves up to a 2x speedup over the original FlashAttention and up to 10x over standard PyTorch implementations. When used to train GPT-style models end to end, it reaches throughput of up to 225 TFLOPs/s per GPU at sequence lengths up to 8k tokens.
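For context, 225 TFLOPs/s can be translated into model FLOPs utilization (MFU) against the A100's peak throughput. The back-of-the-envelope check below assumes the commonly cited BF16/FP16 tensor-core peak of 312 TFLOPs/s.

```python
# Back-of-the-envelope model FLOPs utilization (MFU) check.
# Assumes A100 peak BF16/FP16 tensor-core throughput of 312 TFLOPs/s.
achieved_tflops = 225.0
a100_peak_tflops = 312.0
mfu = achieved_tflops / a100_peak_tflops
print(f"MFU ~= {mfu:.0%}")   # roughly 72% of peak
```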
The researchers highlight that FlashAttention-2 makes it economically feasible to train models on much longer sequences than before. For example, one could train on 16k-token sequences for roughly the cost that 8k-token sequences required before. This could enable Transformer models that understand whole books, long articles, images, or videos rather than just short snippets.
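That cost argument can be sanity-checked with a rough FLOP count. The sketch below uses standard approximations (per token, the attention matmuls cost about 4 * seq_len * d_model FLOPs and the projections plus MLP about 24 * d_model**2 FLOPs) and an arbitrary illustrative model width; it is not a measurement from the paper.

```python
# Rough per-token cost comparison (illustrative assumptions, not paper data).
# Per token: attention matmuls ~ 4 * seq_len * d_model FLOPs,
#            projections + MLP ~ 24 * d_model**2 FLOPs.
d_model = 2048  # illustrative model width

def per_token_cost(seq_len, attention_speedup=1.0):
    attn = 4 * seq_len * d_model / attention_speedup   # effective attention cost
    rest = 24 * d_model ** 2
    return attn + rest

base = per_token_cost(8 * 1024)                              # 8k context, old kernel
doubled = per_token_cost(16 * 1024, attention_speedup=2.0)   # 16k context, 2x faster attention
print(f"16k-with-2x-speedup / 8k-baseline ~= {doubled / base:.2f}")  # close to 1.0
```

Doubling the context doubles the per-token attention work, so a roughly 2x faster attention kernel keeps the overall per-token cost about the same; in practice the balance depends on model size and on how much of the runtime attention actually accounts for.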
Beyond cost, there are also accuracy benefits from longer context. Recent models such as GPT-4 and Claude, with context windows of 32k and 100k tokens respectively, have shown that longer context improves performance on long-document understanding and extended reasoning. FlashAttention-2 makes these gains more accessible to researchers.
Looking ahead, the team plans to further optimize FlashAttention-2 for newer hardware such as Nvidia's H100 GPUs. They also want to combine these low-level optimizations with proposed algorithmic improvements to attention. The end goal is to remove the context-length bottleneck for Transformers entirely, enabling models that can reason about arbitrarily long documents or data.