Fine-tuning Llama 405B dengan GPU AMD
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (perjalanan kami)
Pengenalan
- Seiring model open source makin besar, kebutuhan akan infrastruktur kuat untuk menangani pelatihan AI berskala besar juga meningkat
- Felafax melakukan fine-tuning model LLaMA 3.1 405B di GPU AMD untuk membuktikan efisiensi hardware AMD
- Seluruh pekerjaan dirilis sebagai open source di GitHub
- GPU AMD MI300X menawarkan performa tinggi dibanding hardware AI NVIDIA
- Proyek ini dimungkinkan berkat dukungan TensorWave
Apa itu JAX dan mengapa memilihnya
- JAX adalah library machine learning yang kuat yang menggabungkan API mirip NumPy, diferensiasi otomatis, dan compiler XLA milik Google
- JAX menyediakan API yang sangat baik untuk pemrosesan paralel model sehingga ideal untuk pelatihan model berskala besar
Kelebihan JAX
- Fungsi murni: JAX mendorong penulisan fungsi murni sehingga kode lebih mudah disusun, di-debug, dan dibaca
- Paralelisasi tingkat lanjut: API JIT JAX yang fleksibel mendukung paralelisasi data dan model tingkat lanjut yang penting untuk pelatihan skala besar
- Codebase yang bersih: filosofi desain JAX mendorong penulisan kode yang portabel antar platform hardware
Mengapa JAX unggul di hardware non-NVIDIA
- Pendekatan yang independen dari hardware: JAX memanfaatkan compiler XLA untuk mengompilasi komputasi ke representasi perantara yang independen dari hardware
- Optimisasi yang independen dari platform: compiler XLA melakukan optimisasi tanpa bergantung pada hardware
- Portabilitas yang mudah: dengan JAX, perubahan kode saat beralih dari NVIDIA ke AMD menjadi minimal
Menyiapkan JAX di GPU AMD
- Mengambil image Docker, menjalankan container, lalu memverifikasi instalasi
- Melatih model LLaMA 405B menggunakan 8 GPU AMD MI300x
Pelatihan LLaMA 405B: performa dan skalabilitas
- Melatih model LLaMA 405B di GPU AMD menggunakan JAX
- Melalui fine-tuning LoRA, bobot model dan parameter LoRA disesuaikan dengan presisi bfloat16
- Ukuran model: menggunakan sekitar 800GB VRAM
- Bobot LoRA dan status optimizer: menggunakan sekitar 400GB VRAM
- Total penggunaan VRAM: sekitar 1200GB
- Kecepatan pelatihan: sekitar 35 token per detik
- Efisiensi memori: dipertahankan sekitar 70%
- Skalabilitas: dengan JAX, penskalaan di 8 GPU hampir linear
Konfigurasi pelatihan kami
- Mengonversi LLaMA 3.1 dari PyTorch ke JAX
- Distribusi dilakukan secara efisien melalui pemuatan model dan sharding parameter
Sharding parameter di JAX
- Menggunakan fitur device mesh JAX untuk mendistribusikan model secara efisien ke 8 GPU AMD
- Mendefinisikan aturan sharding parameter untuk membagi dimensi tiap tensor sesuai sumbu mesh
Implementasi pelatihan LoRA
- LoRA mengurangi jumlah parameter yang dapat dilatih dengan memecah pembaruan bobot menjadi matriks berperingkat rendah
- Mengimplementasikan layer LoRADense yang mencakup parameter LoRA
- Mendistribusikan parameter LoRA secara efisien untuk mengoptimalkan penggunaan memori dan efisiensi komputasi
Kesimpulan
- Pengalaman melakukan fine-tuning model LLaMA 3.1 405B dengan GPU AMD dan JAX sangat positif
- Dengan memanfaatkan kemampuan paralelisasi kuat JAX dan pendekatan yang independen dari hardware, model dapat didistribusikan secara efisien
- Ini membuktikan bahwa GPU AMD adalah alternatif yang kuat untuk pelatihan AI berskala besar
- Seluruh kode dapat dilihat dan dijalankan langsung dari repositori GitHub
Ringkasan GN⁺
- Artikel ini menjelaskan cara melatih model AI berskala besar secara efisien menggunakan GPU AMD dan JAX
- Ditekankan bahwa hardware AMD adalah alternatif yang lebih hemat biaya dibanding NVIDIA
- Pendekatan JAX yang independen dari hardware meningkatkan portabilitas kode dan memudahkan pemeliharaan
- Menyediakan informasi berguna dan kode praktik bagi mereka yang tertarik pada pelatihan model skala besar
- Proyek dengan fungsi serupa mencakup CUDA dan PyTorch dari NVIDIA
1 komentar
Komentar Hacker News
Baru-baru ini kami melakukan fine-tuning model llama3.1 405B di 8xAMD MI300x GPU dengan JAX, bukan PyTorch
Berkat API sharding tingkat lanjut dari JAX, performanya bagus, dan teknik sharding yang dipakai sudah kami rangkum di blog. Kodenya juga dibuka: https://github.com/felafax/felafax
Kami adalah startup kecil yang membangun infrastruktur AI untuk fine-tuning dan serving LLM di hardware non-NVIDIA (TPU, AMD, Trainium)
Banyak perusahaan mencoba menjalankan PyTorch di GPU AMD, tetapi PyTorch sangat terkait erat dengan ekosistem NVIDIA, seperti
torch.cudaatauscaled_dot_product_attention, jadi menurut kami perlu banyak “de-NVIDIA-isasi”Menurut kami JAX lebih cocok untuk hardware non-NVIDIA karena kode model dikompilasi menjadi graf HLO yang independen dari hardware, lalu compiler XLA mengoptimalkannya dan menerapkan optimasi spesifik hardware. Kode JAX LLaMA3 yang sama berjalan tanpa modifikasi di Google TPU dan GPU AMD
Strategi perusahaan kami adalah mem-port model ke JAX terlebih dahulu, lalu memanfaatkan framework JAX dan kernel XLA untuk menarik performa maksimal dari backend non-NVIDIA. Jadi kami pertama-tama memindahkan Llama 3.1 dari PyTorch ke JAX, dan model JAX yang sama berjalan baik di TPU dan GPU AMD
Secara pribadi, alasan utama saya memakai PyTorch adalah karena model aslinya dibuat dengan PyTorch. Walau logikanya tampak sama di berbagai versi model, pada skala data yang sangat besar, galat floating-point yang sangat kecil bisa terakumulasi dan menyebabkan model drift
Men-debug ketidakcocokan akurasi seperti ini pada model besar rasanya lebih menyiksa daripada lingkaran neraka ke-10
hipblaslt, Composable Kernel FA, dan semacamnyaSaya tidak terlalu paham JAX, tetapi menurut saya sebagian besar alasan performa training PyTorch di MI300x sangat buruk adalah karena performa library ROCm yang dipakai di dalamnya lambat
Yang saya maksud berjalan di sini bukan kondisi menghabiskan 2 minggu mengurus driver lalu setelah itu server tidak bisa di-update lagi selamanya
Saya juga penasaran dengan masalah teknis yang ditemui
Agar jelas, performa ini cukup buruk. Sepertinya karena kompilasi belum berhasil dibuat berjalan dengan benar
Pada model 405B, hasilnya 35 token/detik, yang setara sekitar 85 teraflops. 8 GPU MI300x berada di kisaran 10,4 petaflops, jadi MFU-nya sekitar 0,8%
Itu 40–50 kali lebih rendah daripada performa training yang layak, yaitu MFU 30–40%, jadi dari sisi AMD mereka mungkin berharap bottleneck-nya ada di software stack
Halaman GitHub mengatakan “bisa melakukan tuning LLaMa3.1 di Google Cloud TPU dengan biaya 30% lebih rendah”, tetapi tidak menyebutkan performanya
Kerja yang bagus. Sekitar setahun lalu saya sempat sedikit mencoba GPU AMD dan dukungan ROCm, dan jelas AMD masih punya jalan panjang untuk mengejar Nvidia
Pendekatan memilih JAX ini menarik, tetapi saya penasaran kesulitan apa saja yang muncul ketika beralih dari PyTorch, yang nyaris menjadi library standar machine learning
Awalnya tujuan kami adalah melakukan fine-tuning LLaMA 3 di TPU, tetapi PyTorch XLA terasa kaku, jadi kami memutuskan menulis ulang modelnya dengan JAX
Seperti disebutkan sebelumnya, kami melihat JAX sebagai platform yang lebih baik untuk GPU non-NVIDIA, dan ingin membangun infrastruktur untuk GPU non-NVIDIA di atas JAX+openXLA
Kerja bagus. Akhir pekan lalu saya juga sedang mengutak-atik sisi inference untuk 405B [0]
Saya tidak yakin
torch.cudaseburuk itu. PyTorch untuk AMD mengonversinya sebagai pengganti. Ini lebih seperti masalah penamaan daripada masalah mendasarDalam praktiknya, menarik container
rocm:pytorchsemudah menarik containerrocm:jaxTidak banyak angka yang dipublikasikan, jadi saya penasaran berapa MFU yang didapat
[0] https://x.com/HotAisle/status/1837580046732874026
MFU harus dihitung. Detail GPU dan VRAM bisa dilihat di repositori: https://dub.sh/amd-405b-res
Akhir pekan depan kami berencana mencoba ulang training run, melakukan JIT compile untuk seluruh langkah training, dan menghitung MFU saat itu
Saat kami mengukurnya di ZML, MI300X 30% lebih cepat daripada H100. Chip yang hebat
Saya penasaran apakah ada penyedia cloud yang menyewakan host 8xAMD MI300
Untuk pekerjaan saya banyak memakai AWS, dan saya ingin mencoba GPU AMD
Di mana data performanya?
Karena keterbatasan kode dan VRAM, kami belum bisa menjalankan versi JIT-compiled dari model 405B. Bagian ini perlu diteliti lebih lanjut
Seluruh training run dilakukan dalam mode eager execution JAX, jadi masih ada banyak ruang untuk peningkatan performa
Bahkan dalam mode eager execution, utilisasi GPU secara umum sekitar 30–40%, yang cukup bagus. Dengan JIT, menurut kami utilisasi GPU bisa dengan mudah naik ke 50–60%
Jika memungkinkan, menarik untuk mengeksplorasi cara mengatasi batasan memori dan menjalankan versi JIT-compiled. Itu bisa menghasilkan peningkatan performa tambahan
Kami membutuhkan langkah training yang di-JIT compile, data loading dan sharding yang lebih dioptimalkan, gradient accumulation, serta activation checkpointing
Kami akan terus membangun dan menerapkan semua peningkatan itu, lalu segera menulis blog lagi
Saya penasaran apakah AMD sudah sedikit saja lebih dekat untuk mengekstrak nilai di sini lewat pesanan GPU massal dan kekurangan pasokan
Kesan saya lebih ke “belum”
Pihak lawan punya keunggulan awal yang sangat besar, dan jelas masih banyak pekerjaan di sisi software. Ini butuh waktu
Kenapa Obsidian, aplikasi catatan itu, melakukan hal ini?