2 poin oleh GN⁺ 2024-09-24 | 1 komentar | Bagikan ke WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (perjalanan kami)

Pengenalan

  • Seiring model open source makin besar, kebutuhan akan infrastruktur kuat untuk menangani pelatihan AI berskala besar juga meningkat
  • Felafax melakukan fine-tuning model LLaMA 3.1 405B di GPU AMD untuk membuktikan efisiensi hardware AMD
  • Seluruh pekerjaan dirilis sebagai open source di GitHub
  • GPU AMD MI300X menawarkan performa tinggi dibanding hardware AI NVIDIA
  • Proyek ini dimungkinkan berkat dukungan TensorWave

Apa itu JAX dan mengapa memilihnya

  • JAX adalah library machine learning yang kuat yang menggabungkan API mirip NumPy, diferensiasi otomatis, dan compiler XLA milik Google
  • JAX menyediakan API yang sangat baik untuk pemrosesan paralel model sehingga ideal untuk pelatihan model berskala besar

Kelebihan JAX

  • Fungsi murni: JAX mendorong penulisan fungsi murni sehingga kode lebih mudah disusun, di-debug, dan dibaca
  • Paralelisasi tingkat lanjut: API JIT JAX yang fleksibel mendukung paralelisasi data dan model tingkat lanjut yang penting untuk pelatihan skala besar
  • Codebase yang bersih: filosofi desain JAX mendorong penulisan kode yang portabel antar platform hardware

Mengapa JAX unggul di hardware non-NVIDIA

  • Pendekatan yang independen dari hardware: JAX memanfaatkan compiler XLA untuk mengompilasi komputasi ke representasi perantara yang independen dari hardware
  • Optimisasi yang independen dari platform: compiler XLA melakukan optimisasi tanpa bergantung pada hardware
  • Portabilitas yang mudah: dengan JAX, perubahan kode saat beralih dari NVIDIA ke AMD menjadi minimal

Menyiapkan JAX di GPU AMD

  • Mengambil image Docker, menjalankan container, lalu memverifikasi instalasi
  • Melatih model LLaMA 405B menggunakan 8 GPU AMD MI300x

Pelatihan LLaMA 405B: performa dan skalabilitas

  • Melatih model LLaMA 405B di GPU AMD menggunakan JAX
  • Melalui fine-tuning LoRA, bobot model dan parameter LoRA disesuaikan dengan presisi bfloat16
  • Ukuran model: menggunakan sekitar 800GB VRAM
  • Bobot LoRA dan status optimizer: menggunakan sekitar 400GB VRAM
  • Total penggunaan VRAM: sekitar 1200GB
  • Kecepatan pelatihan: sekitar 35 token per detik
  • Efisiensi memori: dipertahankan sekitar 70%
  • Skalabilitas: dengan JAX, penskalaan di 8 GPU hampir linear

Konfigurasi pelatihan kami

  • Mengonversi LLaMA 3.1 dari PyTorch ke JAX
  • Distribusi dilakukan secara efisien melalui pemuatan model dan sharding parameter

Sharding parameter di JAX

  • Menggunakan fitur device mesh JAX untuk mendistribusikan model secara efisien ke 8 GPU AMD
  • Mendefinisikan aturan sharding parameter untuk membagi dimensi tiap tensor sesuai sumbu mesh

Implementasi pelatihan LoRA

  • LoRA mengurangi jumlah parameter yang dapat dilatih dengan memecah pembaruan bobot menjadi matriks berperingkat rendah
  • Mengimplementasikan layer LoRADense yang mencakup parameter LoRA
  • Mendistribusikan parameter LoRA secara efisien untuk mengoptimalkan penggunaan memori dan efisiensi komputasi

Kesimpulan

  • Pengalaman melakukan fine-tuning model LLaMA 3.1 405B dengan GPU AMD dan JAX sangat positif
  • Dengan memanfaatkan kemampuan paralelisasi kuat JAX dan pendekatan yang independen dari hardware, model dapat didistribusikan secara efisien
  • Ini membuktikan bahwa GPU AMD adalah alternatif yang kuat untuk pelatihan AI berskala besar
  • Seluruh kode dapat dilihat dan dijalankan langsung dari repositori GitHub

Ringkasan GN⁺

  • Artikel ini menjelaskan cara melatih model AI berskala besar secara efisien menggunakan GPU AMD dan JAX
  • Ditekankan bahwa hardware AMD adalah alternatif yang lebih hemat biaya dibanding NVIDIA
  • Pendekatan JAX yang independen dari hardware meningkatkan portabilitas kode dan memudahkan pemeliharaan
  • Menyediakan informasi berguna dan kode praktik bagi mereka yang tertarik pada pelatihan model skala besar
  • Proyek dengan fungsi serupa mencakup CUDA dan PyTorch dari NVIDIA

1 komentar

 
GN⁺ 2024-09-24
Komentar Hacker News
  • Baru-baru ini kami melakukan fine-tuning model llama3.1 405B di 8xAMD MI300x GPU dengan JAX, bukan PyTorch
    Berkat API sharding tingkat lanjut dari JAX, performanya bagus, dan teknik sharding yang dipakai sudah kami rangkum di blog. Kodenya juga dibuka: https://github.com/felafax/felafax
    Kami adalah startup kecil yang membangun infrastruktur AI untuk fine-tuning dan serving LLM di hardware non-NVIDIA (TPU, AMD, Trainium)
    Banyak perusahaan mencoba menjalankan PyTorch di GPU AMD, tetapi PyTorch sangat terkait erat dengan ekosistem NVIDIA, seperti torch.cuda atau scaled_dot_product_attention, jadi menurut kami perlu banyak “de-NVIDIA-isasi”
    Menurut kami JAX lebih cocok untuk hardware non-NVIDIA karena kode model dikompilasi menjadi graf HLO yang independen dari hardware, lalu compiler XLA mengoptimalkannya dan menerapkan optimasi spesifik hardware. Kode JAX LLaMA3 yang sama berjalan tanpa modifikasi di Google TPU dan GPU AMD
    Strategi perusahaan kami adalah mem-port model ke JAX terlebih dahulu, lalu memanfaatkan framework JAX dan kernel XLA untuk menarik performa maksimal dari backend non-NVIDIA. Jadi kami pertama-tama memindahkan Llama 3.1 dari PyTorch ke JAX, dan model JAX yang sama berjalan baik di TPU dan GPU AMD

    • Tidak ada masalah berarti menjalankan PyTorch di GPU AMD tanpa mengubah kode CUDA. Blog MosaicML juga layak dijadikan referensi: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • Saya penasaran bagaimana kalian memverifikasi akurasi porting JAX untuk Llama 3.1
      Secara pribadi, alasan utama saya memakai PyTorch adalah karena model aslinya dibuat dengan PyTorch. Walau logikanya tampak sama di berbagai versi model, pada skala data yang sangat besar, galat floating-point yang sangat kecil bisa terakumulasi dan menyebabkan model drift
      Men-debug ketidakcocokan akurasi seperti ini pada model besar rasanya lebih menyiksa daripada lingkaran neraka ke-10
    • Saya penasaran apakah JAX punya implementasi sendiri untuk perkalian matriks atau FlashAttention, atau memakai implementasi ROCm seperti PyTorch. Misalnya hipblaslt, Composable Kernel FA, dan semacamnya
      Saya tidak terlalu paham JAX, tetapi menurut saya sebagian besar alasan performa training PyTorch di MI300x sangat buruk adalah karena performa library ROCm yang dipakai di dalamnya lambat
    • Saya penasaran apakah ini juga berjalan di kartu konsumen seperti 7900 XTX
      Yang saya maksud berjalan di sini bukan kondisi menghabiskan 2 minggu mengurus driver lalu setelah itu server tidak bisa di-update lagi selamanya
    • Kalau ini migrasi, saya penasaran apakah ada angka nyata yang dibandingkan dengan versi PyTorch untuk model yang sama. Tabel perbandingan di tulisan itu terlihat lebih ke aspek teknis
      Saya juga penasaran dengan masalah teknis yang ditemui
  • Agar jelas, performa ini cukup buruk. Sepertinya karena kompilasi belum berhasil dibuat berjalan dengan benar
    Pada model 405B, hasilnya 35 token/detik, yang setara sekitar 85 teraflops. 8 GPU MI300x berada di kisaran 10,4 petaflops, jadi MFU-nya sekitar 0,8%
    Itu 40–50 kali lebih rendah daripada performa training yang layak, yaitu MFU 30–40%, jadi dari sisi AMD mereka mungkin berharap bottleneck-nya ada di software stack

    • Saya juga ingin menanyakan hal yang persis sama
      Halaman GitHub mengatakan “bisa melakukan tuning LLaMa3.1 di Google Cloud TPU dengan biaya 30% lebih rendah”, tetapi tidak menyebutkan performanya
  • Kerja yang bagus. Sekitar setahun lalu saya sempat sedikit mencoba GPU AMD dan dukungan ROCm, dan jelas AMD masih punya jalan panjang untuk mengejar Nvidia
    Pendekatan memilih JAX ini menarik, tetapi saya penasaran kesulitan apa saja yang muncul ketika beralih dari PyTorch, yang nyaris menjadi library standar machine learning

    • Beberapa minggu lalu kami memposting Show HN yang menjelaskan perjalanan kami: https://news.ycombinator.com/item?id=41512142
      Awalnya tujuan kami adalah melakukan fine-tuning LLaMA 3 di TPU, tetapi PyTorch XLA terasa kaku, jadi kami memutuskan menulis ulang modelnya dengan JAX
      Seperti disebutkan sebelumnya, kami melihat JAX sebagai platform yang lebih baik untuk GPU non-NVIDIA, dan ingin membangun infrastruktur untuk GPU non-NVIDIA di atas JAX+openXLA
    • Saya belum berhasil menjalankan AMD ROCm di sistem Debian 12 saya, jadi sepertinya Ollama memakai CPU, bukan GPU. Kelihatannya masih jauh jalannya
  • Kerja bagus. Akhir pekan lalu saya juga sedang mengutak-atik sisi inference untuk 405B [0]
    Saya tidak yakin torch.cuda seburuk itu. PyTorch untuk AMD mengonversinya sebagai pengganti. Ini lebih seperti masalah penamaan daripada masalah mendasar
    Dalam praktiknya, menarik container rocm:pytorch semudah menarik container rocm:jax
    Tidak banyak angka yang dipublikasikan, jadi saya penasaran berapa MFU yang didapat
    [0] https://x.com/HotAisle/status/1837580046732874026

    • Bagus
      MFU harus dihitung. Detail GPU dan VRAM bisa dilihat di repositori: https://dub.sh/amd-405b-res
      Akhir pekan depan kami berencana mencoba ulang training run, melakukan JIT compile untuk seluruh langkah training, dan menghitung MFU saat itu
  • Saat kami mengukurnya di ZML, MI300X 30% lebih cepat daripada H100. Chip yang hebat

  • Saya penasaran apakah ada penyedia cloud yang menyewakan host 8xAMD MI300
    Untuk pekerjaan saya banyak memakai AWS, dan saya ingin mencoba GPU AMD

    • Sebagai referensi, perusahaan kami menyewakan 8xMI300x, jadi silakan hubungi kami
    • Oracle menyediakannya. Kemungkinan besar yang lain akan menyusul, tetapi menurut saya penyedia kecil mungkin lebih masuk akal untuk diajak berurusan
  • Di mana data performanya?

    • Kami menambahkan data utilisasi GPU dan VRAM ke repositori GitHub: https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      Karena keterbatasan kode dan VRAM, kami belum bisa menjalankan versi JIT-compiled dari model 405B. Bagian ini perlu diteliti lebih lanjut
      Seluruh training run dilakukan dalam mode eager execution JAX, jadi masih ada banyak ruang untuk peningkatan performa
      Bahkan dalam mode eager execution, utilisasi GPU secara umum sekitar 30–40%, yang cukup bagus. Dengan JIT, menurut kami utilisasi GPU bisa dengan mudah naik ke 50–60%
  • Jika memungkinkan, menarik untuk mengeksplorasi cara mengatasi batasan memori dan menjalankan versi JIT-compiled. Itu bisa menghasilkan peningkatan performa tambahan

    • Setuju. Masih banyak performa yang bisa diperas
      Kami membutuhkan langkah training yang di-JIT compile, data loading dan sharding yang lebih dioptimalkan, gradient accumulation, serta activation checkpointing
      Kami akan terus membangun dan menerapkan semua peningkatan itu, lalu segera menulis blog lagi
  • Saya penasaran apakah AMD sudah sedikit saja lebih dekat untuk mengekstrak nilai di sini lewat pesanan GPU massal dan kekurangan pasokan
    Kesan saya lebih ke “belum”

    • Saya paham sindirannya. Namun pada titik ini, kalau tidak ingin menyerahkan seluruh hardware dan software AI ke satu pemasok tunggal, kita harus mulai bergerak menuju alternatif
      Pihak lawan punya keunggulan awal yang sangat besar, dan jelas masih banyak pekerjaan di sisi software. Ini butuh waktu
  • Kenapa Obsidian, aplikasi catatan itu, melakukan hal ini?

    • Bukan begitu. Perusahaan ini memakai Obsidian Publish untuk menerbitkan dokumen mereka