Felafax BlogTune Llama3 405B on AMD MI300x (perjalanan kami)
Pengenalan
- Seiring model open source makin besar, kebutuhan akan infrastruktur kuat untuk menangani pelatihan AI berskala besar juga meningkat
- Felafax melakukan fine-tuning model LLaMA 3.1 405B di GPU AMD untuk membuktikan efisiensi hardware AMD
- Seluruh pekerjaan dirilis sebagai open source di GitHub
- GPU AMD MI300X menawarkan performa tinggi dibanding hardware AI NVIDIA
- Proyek ini dimungkinkan berkat dukungan TensorWave
Apa itu JAX dan mengapa memilihnya
- JAX adalah library machine learning yang kuat yang menggabungkan API mirip NumPy, diferensiasi otomatis, dan compiler XLA milik Google
- JAX menyediakan API yang sangat baik untuk pemrosesan paralel model sehingga ideal untuk pelatihan model berskala besar
Kelebihan JAX
- Fungsi murni: JAX mendorong penulisan fungsi murni sehingga kode lebih mudah disusun, di-debug, dan dibaca
- Paralelisasi tingkat lanjut: API JIT JAX yang fleksibel mendukung paralelisasi data dan model tingkat lanjut yang penting untuk pelatihan skala besar
- Codebase yang bersih: filosofi desain JAX mendorong penulisan kode yang portabel antar platform hardware
Mengapa JAX unggul di hardware non-NVIDIA
- Pendekatan yang independen dari hardware: JAX memanfaatkan compiler XLA untuk mengompilasi komputasi ke representasi perantara yang independen dari hardware
- Optimisasi yang independen dari platform: compiler XLA melakukan optimisasi tanpa bergantung pada hardware
- Portabilitas yang mudah: dengan JAX, perubahan kode saat beralih dari NVIDIA ke AMD menjadi minimal
Menyiapkan JAX di GPU AMD
- Mengambil image Docker, menjalankan container, lalu memverifikasi instalasi
- Melatih model LLaMA 405B menggunakan 8 GPU AMD MI300x
Pelatihan LLaMA 405B: performa dan skalabilitas
- Melatih model LLaMA 405B di GPU AMD menggunakan JAX
- Melalui fine-tuning LoRA, bobot model dan parameter LoRA disesuaikan dengan presisi bfloat16
- Ukuran model: menggunakan sekitar 800GB VRAM
- Bobot LoRA dan status optimizer: menggunakan sekitar 400GB VRAM
- Total penggunaan VRAM: sekitar 1200GB
- Kecepatan pelatihan: sekitar 35 token per detik
- Efisiensi memori: dipertahankan sekitar 70%
- Skalabilitas: dengan JAX, penskalaan di 8 GPU hampir linear
Konfigurasi pelatihan kami
- Mengonversi LLaMA 3.1 dari PyTorch ke JAX
- Distribusi dilakukan secara efisien melalui pemuatan model dan sharding parameter
Sharding parameter di JAX
- Menggunakan fitur device mesh JAX untuk mendistribusikan model secara efisien ke 8 GPU AMD
- Mendefinisikan aturan sharding parameter untuk membagi dimensi tiap tensor sesuai sumbu mesh
Implementasi pelatihan LoRA
- LoRA mengurangi jumlah parameter yang dapat dilatih dengan memecah pembaruan bobot menjadi matriks berperingkat rendah
- Mengimplementasikan layer LoRADense yang mencakup parameter LoRA
- Mendistribusikan parameter LoRA secara efisien untuk mengoptimalkan penggunaan memori dan efisiensi komputasi
Kesimpulan
- Pengalaman melakukan fine-tuning model LLaMA 3.1 405B dengan GPU AMD dan JAX sangat positif
- Dengan memanfaatkan kemampuan paralelisasi kuat JAX dan pendekatan yang independen dari hardware, model dapat didistribusikan secara efisien
- Ini membuktikan bahwa GPU AMD adalah alternatif yang kuat untuk pelatihan AI berskala besar
- Seluruh kode dapat dilihat dan dijalankan langsung dari repositori GitHub
Ringkasan GN⁺
- Artikel ini menjelaskan cara melatih model AI berskala besar secara efisien menggunakan GPU AMD dan JAX
- Ditekankan bahwa hardware AMD adalah alternatif yang lebih hemat biaya dibanding NVIDIA
- Pendekatan JAX yang independen dari hardware meningkatkan portabilitas kode dan memudahkan pemeliharaan
- Menyediakan informasi berguna dan kode praktik bagi mereka yang tertarik pada pelatihan model skala besar
- Proyek dengan fungsi serupa mencakup CUDA dan PyTorch dari NVIDIA
1 komentar
Komentar Hacker News
Membagikan hasil fine-tuning model Llama3.1 405B menggunakan JAX pada 8x GPU AMD MI300x
Usulan untuk mengeksplorasi cara mengatasi keterbatasan memori dan menjalankan versi yang dikompilasi dengan JIT
Berbagi pengalaman terkait GPU AMD dan dukungan ROCm
Berbagi pengalaman bereksperimen pada sisi inferensi model 405B
torch.cudatidak seburuk iturocm:pytorchsemudah menggunakan containerrocm:jaxPertanyaan tentang tidak adanya data performa
Mempertanyakan mengapa Obsidian (aplikasi pencatat) melakukan hal ini
Meminta @dang untuk menyertakan nama pengguna di URL