FlashAttention-3: Attention yang lebih cepat dan akurat dengan asinkroni dan presisi rendah
(together.ai)-
Pentingnya Attention
- Attention adalah lapisan inti dalam arsitektur Transformer dan menjadi bottleneck pada model bahasa besar serta aplikasi dengan konteks panjang.
- FlashAttention dan FlashAttention-2 memelopori pendekatan untuk mempercepat Attention di GPU dengan meminimalkan operasi baca/tulis memori.
- Hal ini membuat panjang konteks LLM meningkat secara signifikan.
-
Teknologi utama FlashAttention-3
- Memanfaatkan asinkroni: Memanfaatkan sifat asinkron Tensor Cores dan TMA untuk menumpangtindihkan seluruh komputasi dan perpindahan data.
- Operasi per blok: Menjalankan perkalian matriks dan operasi softmax secara bergantian per blok.
- Pemrosesan presisi rendah: Meningkatkan performa dengan dukungan presisi rendah FP8.
-
Peningkatan performa FlashAttention-3
- Efisiensi pemanfaatan GPU: Memanfaatkan hingga 75% performa maksimum GPU H100, sehingga 1,5-2 kali lebih cepat dibanding versi sebelumnya.
- Performa presisi rendah: Menggunakan FP8 untuk meningkatkan kecepatan pemrosesan dan mengurangi penggunaan memori.
- Pemrosesan konteks panjang: Mempercepat mekanisme Attention sehingga teks yang lebih panjang dapat diproses dengan efisien.
-
Ringkasan FlashAttention
- FlashAttention menyusun ulang komputasi Attention dan memanfaatkan tiling serta rekalkulasi untuk secara signifikan meningkatkan kecepatan dan mengurangi penggunaan memori.
- Melalui tiling, blok input dimuat, Attention dijalankan pada blok tersebut, lalu output diperbarui.
- Jumlah operasi baca/tulis memori dikurangi dengan tidak menuliskan matriks Attention antara ke memori.
-
Fitur hardware baru pada GPU Hopper
- WGMMA: Memberikan throughput tinggi dengan memanfaatkan Tensor Cores baru.
- TMA: Unit hardware yang mempercepat transfer data antara global memory dan shared memory.
- Presisi rendah FP8: Menggandakan throughput Tensor Core dengan menggunakan FP8.
-
Asinkroni: menumpangtindihkan GEMM dan Softmax
- Kebutuhan overlap: Memaksimalkan performa dengan menjalankan GEMM dan softmax secara paralel.
- Penjadwalan ping-pong: Dua kelompok warp secara bergantian menjalankan GEMM dan softmax untuk meningkatkan performa.
- Overlap dalam kelompok warp: Meningkatkan throughput dengan menjalankan GEMM dan softmax secara paralel dalam kelompok warp yang sama.
-
Presisi rendah: pemrosesan incoherent untuk mengurangi error kuantisasi
- Pemrosesan incoherent: Mengurangi error kuantisasi dengan menggunakan transformasi Hadamard.
- Hasil eksperimen: Pemrosesan incoherent mengurangi error kuantisasi hingga 2,6 kali.
-
Benchmark Attention
- FP16: Sekitar 1,6-1,8 kali lebih cepat dibanding FlashAttention-2.
- FP8: Mencapai hingga 1,2 PFLOPS.
Ringkasan GN⁺
- FlashAttention-3 secara signifikan meningkatkan performa mekanisme Attention dengan memanfaatkan fitur hardware baru pada GPU.
- Teknologi ini dapat memproses konteks panjang secara efisien sehingga memaksimalkan performa model bahasa besar.
- Karena kemungkinan besar akan diintegrasikan ke framework utama seperti PyTorch, dampaknya terhadap riset dan aplikasi AI ke depan diperkirakan besar.
- Proyek lain yang menyediakan fungsi serupa antara lain Triton dan cuDNN.
1 komentar
Komentar Hacker News
Tampaknya Tri Dao mulai mengerjakan FA3 sejak April 2022
Ada yang penasaran seberapa bergantung algoritme Flash Attention pada hardware
Ada yang penasaran apakah compiler bisa menemukan sendiri optimisasi seperti FlashAttention
Siapa pun yang ingin melakukan port ke ROCm/AMD MI300x diminta untuk menghubungi
TMA (Tensor Memory Accelerator) adalah unit hardware yang mempercepat transfer data antara memori global dan memori bersama
FlashAttention-3 dioptimalkan untuk GPU Hopper (misalnya H100)
Disebutkan bahwa fungsi aktivasi seperti sigmoid sangat lambat pada LLM modern
Ada yang penasaran mengapa Flash Attention 5 kali lebih lambat saat ada variable masking dibanding saat tidak ada
Ada yang penasaran apakah FlashAttention bisa menggantikan operasi attention pada LLM
Membutuhkan hardware yang mahal