Cara Men-scale Model Anda: Perspektif Sistem terhadap LLM di TPU
(jax-ml.github.io)- Mengoptimalkan performa deep learning dalam skala besar sering tampak seperti “alkimia”, tetapi pada praktiknya efisiensi model dapat ditingkatkan dengan prinsip-prinsip sederhana yang bisa dipahami
- Dari satu akselerator hingga puluhan ribu akselerator, prinsip yang relatif sederhana berlaku di mana-mana, dan dengan memahaminya kita bisa melakukan hal-hal berguna seperti berikut:
- Memperkirakan secara kasar seberapa dekat tiap bagian model dengan nilai optimal teoretis
- Menyusun dasar untuk memilih berbagai teknik paralelisasi pada beragam skala
- Memperkirakan biaya dan waktu yang diperlukan untuk melatih dan menjalankan model Transformer besar
- Merancang algoritme yang memanfaatkan karakteristik perangkat keras tertentu
- Merancang perangkat keras dengan memahami secara jelas batas performa algoritme saat ini
- Pengetahuan latar yang dibutuhkan
- Perlu memahami konsep dasar LLM dan arsitektur Transformer
- Pemahaman tentang cara operasi berskala besar bukan keharusan
- Akan lebih baik jika memiliki pengetahuan dasar pelatihan LLM dan pengalaman menggunakan JAX
- Disarankan merujuk ke posting blog tentang arsitektur Transformer dan slide tentang scaling LLM di JAX
- Tujuan
- Mengembangkan kemampuan untuk memperkirakan cara terbaik memparalelkan model pada perangkat keras yang tersedia
- Mengembangkan kemampuan untuk menghitung secara kasar waktu dan biaya pelatihan maupun inferensi
Mengapa ini penting
- Bahkan 3~4 tahun lalu, sebagian besar peneliti ML tidak perlu memahami optimasi skala besar seperti ini secara mendalam
- Kini, bahkan model yang “kecil” pun berjalan mendekati batas perangkat keras, sehingga memahami cara kerja efisien pada skala besar menjadi hal esensial
- Sejarah ML dapat dilihat sebagai alur perkembangan silang antara inovasi sistem dan peningkatan perangkat lunak
- Karena model Transformer belakangan ini menggunakan perangkat keras hingga batasnya, tanpa memahami efisiensi model, arsitektur atau riset baru berisiko gagal saat diterapkan di dunia nyata
- Bahkan jika benchmark menunjukkan peningkatan performa 20%, bila efisiensi perangkat keras turun 20%, pada akhirnya kegunaannya menjadi rendah
- Tujuan inti scaling model adalah membuat throughput meningkat secara linear ketika jumlah chip (akselerator) ditambah
- Ini disebut "strong scaling"
- Menambah chip mengurangi waktu komputasi tetapi menimbulkan biaya komunikasi antar-chip
- Jika komunikasi memakan waktu lebih lama daripada komputasi, sistem menjadi berada dalam kondisi "communication bound" sehingga strong scaling tidak mungkin dicapai
- Jika kita cukup memahami perangkat keras untuk memprediksi di mana bottleneck seperti ini akan muncul, kita bisa merancang atau menyusun ulang model untuk mencegahnya
- Tujuan buku ini adalah menjelaskan cara kerja perangkat keras TPU (dan GPU) serta bagaimana arsitektur Transformer berkembang agar berjalan baik di perangkat keras saat ini
- Diharapkan ini bermanfaat baik bagi peneliti yang merancang arsitektur baru maupun engineer yang berusaha menjalankan LLM generasi sekarang dengan cepat
Gambaran umum
- Tulisan ini disusun sebagai berikut
- Bagian 1 menjelaskan faktor-faktor yang menentukan batas performa model (komunikasi, komputasi, memori) melalui analisis roofline
- Bagian 2, Bagian 3 membahas struktur internal TPU dan GPU serta cara koneksi antar-chip
- Melalui ini, tulisan menjawab pertanyaan seperti berikut
- Seberapa cepat perkalian matriks berukuran tertentu secara teoretis dapat dijalankan
- Pada titik mana komputasi menjadi dibatasi oleh bandwidth memori atau bandwidth komunikasi
- Dengan struktur seperti apa klaster TPU dihubungkan, dan kira-kira berapa lama waktu yang dibutuhkan untuk memindahkan data dari satu chip ke chip lain
- Bagaimana mengalikan matriks terdistribusi secara efisien
- Melalui ini, tulisan menjawab pertanyaan seperti berikut
- Bagian 4 membahas secara rinci rumus-rumus arsitektur Transformer (ukuran matriks, jumlah parameter, FLOPs)
- Bagian 5 dan Bagian 7 adalah bagian inti, yang memperkenalkan berbagai cara memparalelkan model ke banyak chip
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Juga membahas teknik penghematan memori seperti ZeRO, Rematerialisation, Host offload, Gradient accumulation
- Bagian 6, Bagian 8 menggunakan contoh pelatihan dan inferensi model LLaMA-3 di TPU untuk menunjukkan biaya, waktu, dan konfigurasi nyata
- Terakhir, Bagian 9, Bagian 10 membahas cara praktis melakukan profiling, debugging, dan menerapkan pemrosesan paralel pada model di JAX
Rincian lebih lanjut: ringkasan bagian-bagian utama buku
-
Bagian 1: Preliminaries
-
Bagian 1: Pengantar singkat analisis Roofline
- Tiga faktor yang membatasi algoritme: komputasi, komunikasi, memori
- Dari sini dipelajari cara memperkirakan batas atas kecepatan komputasi
-
- Bagaimana TPU melakukan komputasi
- Apa itu struktur systolic array
- Pemahaman dasar tentang bagaimana TPU menyediakan bandwidth memori dan komunikasi
-
Bagian 3: Matriks terdistribusi dan perkalian terdistribusi
- Teknik menyimpan parameter model dengan membaginya ke banyak chip (sharding)
- Cara menangani komunikasi dan bottleneck yang muncul saat operasi matriks terdistribusi
-
-
Bagian 2: Transformers
-
Bagian 4: Ringkasan rumus Transformer yang diperlukan
- Bentuk konkret perkalian matriks dalam Transformer
- Cara menghitung jumlah parameter, FLOPs, ukuran KV cache, dan sebagainya
- Memahami seberapa besar komputasi yang dibutuhkan operasi Attention dibanding blok Feed-Forward
-
Bagian 5: Strategi paralelisasi untuk pelatihan Transformer
- Pengenalan teknik Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Opsi penghematan memori seperti ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload
- Pembentukan konsep untuk menyusun paralelisasi sesuai ukuran model dan jumlah chip tertentu
-
Bagian 6: Penerapan pelatihan LLaMA 3 di TPU
- Estimasi waktu dan biaya jika diasumsikan melatih model LLaMA 3 di lingkungan TPU nyata
- Memberikan contoh konkret terkait batch size, metode paralelisasi, penggunaan memori, dan lain-lain
-
Bagian 7: Semua tentang inferensi Transformer
- Dalam inferensi, latensi muncul sebagai faktor penting yang baru
- Penggunaan memori dan masalah komunikasi akibat KV cache dan sebagainya
- Diskusi tentang bagaimana mengalokasikan dan menghubungkan banyak chip untuk model serving
-
Bagian 8: Penerapan serving LLaMA 3 di TPU
- Analisis perkiraan biaya, latensi, dan trade-off throughput saat diasumsikan melakukan serving LLaMA 3 di TPU v5e
-
-
Bagian 3: Tutorial Praktis
-
Bagian 9: Cara profiling kode TPU
- Memahami stack JAX+XLA
- Mengidentifikasi isu penurunan performa nyata dan solusinya
- Cara menggunakan profiler JAX/TensorBoard
-
Bagian 10: Memprogram TPU dengan JAX
- Cara memanfaatkan API paralelisasi JAX (primitives)
- Mempelajari konsep komputasi paralel melalui contoh dan latihan
-
Bagian 11: Kesimpulan dan materi tambahan
- Bacaan lanjutan tentang TPU dan LLM
- Menutup keseluruhan isi secara singkat sambil menyinggung prospek ke depan
-
1 komentar
Komentar Hacker News