2 poin oleh GN⁺ 2024-09-24 | 1 komentar | Bagikan ke WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (perjalanan kami)

Pengenalan

  • Seiring model open source makin besar, kebutuhan akan infrastruktur kuat untuk menangani pelatihan AI berskala besar juga meningkat
  • Felafax melakukan fine-tuning model LLaMA 3.1 405B di GPU AMD untuk membuktikan efisiensi hardware AMD
  • Seluruh pekerjaan dirilis sebagai open source di GitHub
  • GPU AMD MI300X menawarkan performa tinggi dibanding hardware AI NVIDIA
  • Proyek ini dimungkinkan berkat dukungan TensorWave

Apa itu JAX dan mengapa memilihnya

  • JAX adalah library machine learning yang kuat yang menggabungkan API mirip NumPy, diferensiasi otomatis, dan compiler XLA milik Google
  • JAX menyediakan API yang sangat baik untuk pemrosesan paralel model sehingga ideal untuk pelatihan model berskala besar

Kelebihan JAX

  • Fungsi murni: JAX mendorong penulisan fungsi murni sehingga kode lebih mudah disusun, di-debug, dan dibaca
  • Paralelisasi tingkat lanjut: API JIT JAX yang fleksibel mendukung paralelisasi data dan model tingkat lanjut yang penting untuk pelatihan skala besar
  • Codebase yang bersih: filosofi desain JAX mendorong penulisan kode yang portabel antar platform hardware

Mengapa JAX unggul di hardware non-NVIDIA

  • Pendekatan yang independen dari hardware: JAX memanfaatkan compiler XLA untuk mengompilasi komputasi ke representasi perantara yang independen dari hardware
  • Optimisasi yang independen dari platform: compiler XLA melakukan optimisasi tanpa bergantung pada hardware
  • Portabilitas yang mudah: dengan JAX, perubahan kode saat beralih dari NVIDIA ke AMD menjadi minimal

Menyiapkan JAX di GPU AMD

  • Mengambil image Docker, menjalankan container, lalu memverifikasi instalasi
  • Melatih model LLaMA 405B menggunakan 8 GPU AMD MI300x

Pelatihan LLaMA 405B: performa dan skalabilitas

  • Melatih model LLaMA 405B di GPU AMD menggunakan JAX
  • Melalui fine-tuning LoRA, bobot model dan parameter LoRA disesuaikan dengan presisi bfloat16
  • Ukuran model: menggunakan sekitar 800GB VRAM
  • Bobot LoRA dan status optimizer: menggunakan sekitar 400GB VRAM
  • Total penggunaan VRAM: sekitar 1200GB
  • Kecepatan pelatihan: sekitar 35 token per detik
  • Efisiensi memori: dipertahankan sekitar 70%
  • Skalabilitas: dengan JAX, penskalaan di 8 GPU hampir linear

Konfigurasi pelatihan kami

  • Mengonversi LLaMA 3.1 dari PyTorch ke JAX
  • Distribusi dilakukan secara efisien melalui pemuatan model dan sharding parameter

Sharding parameter di JAX

  • Menggunakan fitur device mesh JAX untuk mendistribusikan model secara efisien ke 8 GPU AMD
  • Mendefinisikan aturan sharding parameter untuk membagi dimensi tiap tensor sesuai sumbu mesh

Implementasi pelatihan LoRA

  • LoRA mengurangi jumlah parameter yang dapat dilatih dengan memecah pembaruan bobot menjadi matriks berperingkat rendah
  • Mengimplementasikan layer LoRADense yang mencakup parameter LoRA
  • Mendistribusikan parameter LoRA secara efisien untuk mengoptimalkan penggunaan memori dan efisiensi komputasi

Kesimpulan

  • Pengalaman melakukan fine-tuning model LLaMA 3.1 405B dengan GPU AMD dan JAX sangat positif
  • Dengan memanfaatkan kemampuan paralelisasi kuat JAX dan pendekatan yang independen dari hardware, model dapat didistribusikan secara efisien
  • Ini membuktikan bahwa GPU AMD adalah alternatif yang kuat untuk pelatihan AI berskala besar
  • Seluruh kode dapat dilihat dan dijalankan langsung dari repositori GitHub

Ringkasan GN⁺

  • Artikel ini menjelaskan cara melatih model AI berskala besar secara efisien menggunakan GPU AMD dan JAX
  • Ditekankan bahwa hardware AMD adalah alternatif yang lebih hemat biaya dibanding NVIDIA
  • Pendekatan JAX yang independen dari hardware meningkatkan portabilitas kode dan memudahkan pemeliharaan
  • Menyediakan informasi berguna dan kode praktik bagi mereka yang tertarik pada pelatihan model skala besar
  • Proyek dengan fungsi serupa mencakup CUDA dan PyTorch dari NVIDIA

1 komentar

 
GN⁺ 2024-09-24
Komentar Hacker News
  • Membagikan hasil fine-tuning model Llama3.1 405B menggunakan JAX pada 8x GPU AMD MI300x

    • Mencapai performa yang sangat baik berkat API sharding tingkat lanjut dari JAX
    • Menyediakan tautan ke posting blog dan kode open source: tautan GitHub
    • Merupakan startup yang membangun infrastruktur AI untuk fine-tuning dan melayani LLM di TPU, AMD, dan Trainium, bukan pada hardware NVIDIA
    • Menilai banyak perusahaan mencoba menjalankan PyTorch di GPU AMD, tetapi itu adalah jalan yang sulit
    • PyTorch sangat terikat dengan ekosistem NVIDIA sehingga perlu banyak modifikasi agar bisa berjalan di hardware non-NVIDIA
    • Percaya bahwa JAX lebih cocok untuk hardware non-NVIDIA
    • Di JAX, kode model ML dikompilasi menjadi graf HLO yang independen dari hardware, lalu kompiler XLA melakukan optimasi spesifik hardware
    • Kode JAX yang sama dapat dijalankan di Google TPU dan GPU AMD tanpa perubahan
    • Strategi perusahaan adalah mem-porting model ke JAX dan memanfaatkan kernel XLA untuk mengekstrak performa maksimum dari backend non-NVIDIA
    • Mereka pertama kali mem-porting Llama 3.1 dari PyTorch ke JAX, dan kini model JAX yang sama berjalan baik di TPU maupun GPU AMD
    • Ingin mendengar pendapat tentang visi dan repositorinya
  • Usulan untuk mengeksplorasi cara mengatasi keterbatasan memori dan menjalankan versi yang dikompilasi dengan JIT

    • Hal itu kemungkinan dapat membawa peningkatan performa tambahan
  • Berbagi pengalaman terkait GPU AMD dan dukungan ROCm

    • Setahun lalu mencoba GPU AMD dan dukungan ROCm, tetapi merasa AMD masih jauh dari mengejar NVIDIA
    • Memilih JAX adalah pendekatan yang menarik, tetapi penasaran kesulitan apa yang dihadapi saat keluar dari PyTorch
  • Berbagi pengalaman bereksperimen pada sisi inferensi model 405B

    • Menganggap torch.cuda tidak seburuk itu
    • Menilai ini hanya masalah penamaan karena versi AMD dari PyTorch menerjemahkannya
    • Menggunakan container rocm:pytorch semudah menggunakan container rocm:jax
    • Menunjukkan bahwa belum banyak data performa yang dipublikasikan
    • Penasaran dengan angka MFU (tingkat utilisasi model)
  • Pertanyaan tentang tidak adanya data performa

    • Mempertanyakan kemungkinan mengekstrak nilai dari pemesanan GPU AMD dalam jumlah besar
    • Mendapat kesan bahwa jawabannya adalah "tidak"
  • Mempertanyakan mengapa Obsidian (aplikasi pencatat) melakukan hal ini

    • Awalnya mengira ini adalah posting dari Obsidian
    • Mempertanyakan mengapa masih belum dibedakan antara GitHub.com dan GitHub.io
  • Meminta @dang untuk menyertakan nama pengguna di URL

    • Posting ini membahas blog buatan pengguna, bukan Obsidian itu sendiri