14 poin oleh xguru 2024-08-19 | 8 komentar | Bagikan ke WhatsApp
  • Alasan PyTorch menyebabkan hilangnya produktivitas dan membuang waktu pengembangan adalah karena "bukan framework-nya yang buruk, melainkan karena framework itu tidak dirancang untuk use case yang diterapkan saat ini"

Filosofi PyTorch

  • Filosofi PyTorch adalah dinamis, mudah di-debug, dan Pythonic
  • Sebaliknya, TensorFlow 1.x berupaya menjadi framework yang statis namun berkinerja baik dengan sangat mengandalkan kompiler XLA
  • Para pengembang TensorFlow menyadari bahwa komunitas tidak menyukai API 1.x, lalu memutuskan menggunakan Keras sebagai antarmuka utama dan mengurangi peran kompiler XLA
  • PyTorch mempertahankan akarnya, dan tidak seperti pendekatan TensorFlow yang statis dan tertunda, ia mengadopsi pendekatan "eager execution" yang lebih dinamis, di mana torch.Tensor dievaluasi segera
  • Hal ini membuahkan hasil, dan banyak riset berpindah ke PyTorch
  • Pada 2021, ketika GPT-3 muncul, performa dan skalabilitas menjadi perhatian utama
  • PyTorch merespons kebutuhan ini dengan cukup baik sampai titik tertentu, tetapi karena tidak dirancang dengan filosofi tersebut, utangnya makin menumpuk dan fondasinya mulai goyah
  • Para pengembang PyTorch tidak menginginkan kompromi apa pun dan memilih mengejar dua jalur sekaligus
    • Menggunakan kompiler XLA sebagai backend dasar dengan performa dan stabilitas tinggi
    • Membangun stack torch.compile agar pengguna bebas memanggil kompiler saat diperlukan
  • Tidak adanya strategi jangka panjang adalah masalah serius
  • PyTorch tidak ingin berkomitmen pada filosofi yang berpusat pada kompiler seperti JAX, tetapi juga tidak terlihat ada alternatif yang bagus
  • Lalu, bagaimana solusi produk pesaing untuk masalah ini?

Pengembangan berbasis kompiler di JAX

  • JAX memanfaatkan XLA, stack kompiler kuat dari TensorFlow
  • XLA adalah kompiler yang kuat, tetapi semuanya diabstraksikan dari end user
  • Selama fungsi bersifat pure, fungsi itu bisa di-JIT compile dengan dekorator @jax.jit agar dapat digunakan di XLA
  • XLA menangani semuanya di balik layar: memverifikasi apakah graph yang dihasilkan benar, partisioner GSPMD yang menangani paralelisasi otomatis dengan sharding di JAX, optimisasi graph, fusi operator dan kernel, penjadwalan untuk menyembunyikan latensi, overlap komunikasi asinkron, pembuatan kode untuk backend lain seperti triton, dan lainnya
  • Selama mematuhi batasan JAX, XLA akan menanganinya secara otomatis
  • Misalnya, saat melakukan paralelisasi, tidak diperlukan primitive komunikasi seperti torch.distributed.barrier()
  • Dukungan DDP dimungkinkan dengan kode yang sederhana
  • Pendekatan XLA adalah bahwa komputasi mengikuti sharding. Jadi, jika array input di-shard sepanjang sumbu tertentu, XLA akan menanganinya secara otomatis untuk subkomputasi
  • Ide "pengembangan berbasis kompiler" ini mirip dengan cara kerja kompiler Rust
  • Keterbatasan PyTorch
    • Ada ketidakpuasan terhadap pilihan pengembang PyTorch yang mengintegrasikan dan bergantung pada stack kompiler untuk fitur baru alih-alih mempertahankan filosofi inti fleksibilitas dan kebebasan
    • Menurut roadmap resmi PyTorch 2.x, ada rencana jangka panjang yang jelas untuk mengintegrasikan XLA sepenuhnya ke dalam Torch
    • Ini adalah ide yang buruk. Ibarat mengatakan bahwa memaksa kode C++ masuk ke kompiler Rust akan menjadi pengalaman yang lebih baik daripada memakai Rust itu sendiri
    • Tidak seperti JAX, Torch tidak dirancang berpusat pada XLA
    • Jika PyTorch memutuskan menggunakan stack kompiler berbasis XLA, bukankah framework yang ideal justru yang memang dirancang dan dibangun khusus di sekelilingnya?
    • Bahkan jika PyTorch mengejar pendekatan "multi-backend" yang memungkinkan memilih backend kompiler sesuka hati, bukankah itu akan memperburuk fragmentasi dan merusak API habis-habisan sambil berusaha menghormati batasan semua stack kompiler?
    • Siapa pun yang pernah memakai Torch/XLA di TPU pasti menderita PTSD yang serius

Multi-Backend itu gagal

  • PyTorch gagal total karena mencoba melakukan semuanya sekaligus
  • Keputusan desain "multi-backend" memperburuk masalah ini secara eksponensial
  • Secara teori terdengar seperti bisa memilih stack yang diinginkan, tetapi dalam praktiknya itu adalah kekacauan traceback yang sulit dipahami dan masalah inkompatibilitas
  • Benturan antara batasan antar-backend dan API PyTorch
    • Bukan hanya sulit membuat backend-backend ini berfungsi, tetapi batasan yang mereka harapkan tidak cocok dengan API PyTorch yang fleksibel dan Pythonic
    • Ada trade-off antara menjaga konsistensi API dan mematuhi batasan backend
    • Akibatnya, para pengembang cenderung lebih bergantung pada code generation alih-alih benar-benar mengintegrasikan/berkomitmen pada satu backend
  • Tidak adanya strategi di PyTorch
    • Karena PyTorch menolak trade-off yang bermakna, setiap keputusan terasa seperti kompromi
    • Tidak ada konsistensi, juga tidak ada strategi menyeluruh
    • Pada akhirnya, ini menyebabkan banyak frustrasi bagi pengguna dan terasa seperti tumpukan fitur yang tidak cocok satu sama lain
    • Tidak ada cara yang lebih cepat untuk membunuh ekosistem
  • Mengapa tidak boleh mengikuti pendekatan JAX
    • PyTorch tidak seharusnya mengikuti pendekatan JAX yang "kompiler dan backend terintegrasi"
    • Karena JAX memang secara eksplisit dirancang untuk bekerja bersama XLA
    • Mengganti frontend PyTorch dengan milik JAX bukanlah strategi
    • Hampir mustahil merancang API yang lebih baik dari JAX di atas dasar XLA
    • Tidak menyalahkan para pengembang karena mencoba ide-ide baru yang berbeda
    • Tetapi jika PyTorch ingin bertahan dalam ujian waktu, ia harus lebih fokus memperkuat fondasinya daripada menghadirkan fitur-fitur baru keren yang langsung runtuh di luar kondisi tutorial yang ideal

Fragmentasi PyTorch dan pemrograman fungsional JAX

  • API fungsional JAX
    • Fungsi JAX harus pure, artinya tidak boleh ada efek samping global
    • Seperti fungsi matematika, jika diberi data yang sama, ia harus selalu mengembalikan output yang sama terlepas dari konteks eksekusinya
    • Berkat filosofi desain ini, fungsi-fungsi JAX dapat dikomposisikan dan saling interoperabel dengan baik
    • Kompleksitas pengembangan berkurang, dan fungsi didefinisikan sebagai signature tertentu dengan pekerjaan konkret yang terdefinisi jelas
    • Selama tipe terpenuhi, fungsi dijamin langsung bekerja
    • Ini cocok untuk jenis pekerjaan yang dibutuhkan dalam komputasi ilmiah, khususnya deep learning
  • Contoh API optax
    • Berkat pendekatan fungsional, optax memiliki sesuatu yang disebut "chain"
    • Ini mencakup beberapa fungsi yang diterapkan berurutan pada gradient
    • Komponen dasarnya adalah GradientTransformation
    • Ini menghasilkan API yang kuat namun ekspresif
    • Misalnya, hal seperti clipping gradient, mengambil EMA dari gradient, atau menggabungkan optimizer menjadi sangat sederhana
  • Keunggulan desain fungsional
    • Hasil menarik lain dari desain fungsional adalah vmap
    • Ini berarti map yang "vectorized", dan namanya menjelaskan fungsi itu dengan tepat
    • Semuanya bisa di-map, dan selama itu vmap, XLA akan otomatis menggabungkan dan mengoptimalkannya
    • Saat menulis fungsi, tidak perlu memikirkan dimensi batch
    • Cukup vmap semua kode
    • Ini berarti kebutuhan akan operasi ein-* menjadi lebih sedikit
    • Memahami manipulasi tensor 2D/3D menjadi lebih intuitif dan keterbacaannya juga jauh lebih baik
    • Karena cukup mengisolasi komponen individual lalu menalarinya, menjadi lebih mudah menulis kode kompleks yang benar-benar bekerja
    • Selama mematuhi batasan kemurnian dan memiliki signature yang benar, semua keuntungan lain seperti komposabilitas bisa dinikmati
  • Masalah ekosistem PyTorch
    • Di torch, apa pun stack yang digunakan (FSDP + multi-node + torch.compile dan sebagainya), selalu ada kemungkinan sesuatu rusak
    • Banyak hal harus bekerja bersama dengan benar, dan jika satu komponen saja gagal, Anda harus debugging sampai jam 3 pagi
    • Karena tidak mungkin menguji semua kombinasi dari puluhan fitur yang ditawarkan PyTorch, akan selalu ada bug yang tidak ditemukan selama pengembangan
    • Mustahil menulis kode yang bekerja dengan baik tanpa usaha yang besar
    • Ekosistem torch menjadi sangat bengkak dan penuh bug
    • Karena tidak ada abstraksi bersama, muncullah library dan framework baru yang tidak dirancang untuk berinteraksi dengan "solusi" lain
    • Ini segera berubah menjadi kekacauan dependensi dan requirements.txt
    • 70-80% issue GitHub atau diskusi forum semata-mata terjadi karena error muncul dari library yang berbeda-beda
    • Hampir tidak ada cara untuk memperbaikinya
  • Tidak adanya solusi
    • Ini adalah masalah OOP dan desain
    • Terpikir bahwa objek dasar yang sangat khas PyTorch seperti PyTree mungkin dapat membantu membangun fondasi bersama untuk abstraksi
    • Juga tidak bisa mengadopsi paradigma pemrograman fungsional
    • Jika dilakukan, hasilnya akan menjadi versi JAX yang lebih buruk performanya sambil merusak kompatibilitas mundur seluruh codebase torch yang ada
    • PyTorch tampak benar-benar kacau di bagian ini

Keunggulan JAX dalam reproduksibilitas

  • Penanganan seed
    • Penanganan seed di PyTorch tidak ideal
    • Biasanya perlu menjalankan beberapa baris kode
    • Mudah terlupa atau salah konfigurasi
    • JAX memaksa pembuatan key eksplisit dan mengoperkannya ke semua fungsi yang memerlukan random
    • Pendekatan ini sepenuhnya menghilangkan masalah karena RNG selalu di-seed secara statis
    • JAX punya versi NumPy sendiri (jax.numpy), jadi tidak perlu mengatur seed secara terpisah
    • Keputusan QoL kecil seperti ini bisa membuat pengalaman pengguna seluruh framework jauh lebih baik
  • Portabilitas
    • Salah satu masalah terbesar saat memakai codebase PyTorch adalah kurangnya portabilitas
    • Codebase yang ditulis untuk CUDA/GPU tidak bekerja dengan baik saat dijalankan di hardware non-Nvidia seperti TPU, NPU, AMD GPU, dan lain-lain
    • Sulit mem-porting kode PyTorch yang ditulis untuk 1 node ke multi-node
    • Multi-node sering kali membutuhkan puluhan jam waktu pengembangan dan perubahan kode yang signifikan
    • Pendekatan JAX yang berpusat pada kompiler punya keunggulan di sini
    • XLA menangani perpindahan antar-backend perangkat, dan dengan perubahan kode minimal dapat bekerja baik di GPU/TPU/multi-node/multi-slice
    • Ini memudahkan vendor hardware untuk mendukung perangkat mereka dan mempermudah perpindahan antar-perangkat
    • Tidak semua orang punya akses ke hardware yang sama, jadi codebase yang portabel di berbagai jenis hardware bisa menjadi langkah kecil untuk membuat deep learning lebih mudah diakses oleh pemula dan tingkat menengah
  • Auto-scaling
    • Codebase yang bisa melakukan auto-scaling dengan baik sendiri sangat membantu reproduksi
    • Dalam kondisi ideal, ini harus terjadi otomatis dengan perubahan kode minimal, tanpa terikat oleh batas jaringan
    • JAX melakukannya dengan baik
    • Saat menulis kode JAX, tidak perlu menentukan primitive komunikasi atau menaruh torch.distributed.barrier() di mana-mana
    • XLA akan otomatis menyisipkannya dengan mempertimbangkan hardware yang tersedia
    • Semua perangkat yang bisa dideteksi JAX akan otomatis digunakan, terlepas dari networking, topologi, konfigurasi, dan sebagainya
    • Ia otomatis menyinkronkan dan menyiapkan komputasi serta menerapkan optimization pass untuk memaksimalkan eksekusi kernel asinkron dan meminimalkan latensi
    • Yang perlu dilakukan manusia hanyalah menentukan sharding tensor yang ingin didistribusikan ke perangkat, seperti dimensi batch dari array input
    • Karena pendekatan XLA bahwa "komputasi mengikuti sharding", sisanya akan dipahami secara otomatis
    • Eksperimen yang tervalidasi pada skala bisa dijalankan dengan mudah sebagai hobi untuk bereksperimen dan mungkin diulang
    • Ini bisa mempermudah penemuan kembali ide-ide yang terlupakan, dan mendorong eksperimen semacam itu karena dapat dengan mudah diuji sebagai fungsi pada skala yang lebih besar dengan usaha minimal

Kekurangan JAX

  • Struktur tata kelola
    • Saat ini XLA berada di bawah tata kelola TensorFlow
    • Pernah ada diskusi tentang pembentukan badan organisasi terpisah yang mirip dengan PyTorch, tetapi belum banyak upaya konkret yang dilakukan
    • Tingkat kepercayaan terhadap Google tidak tinggi karena reputasinya menghentikan produk yang tidak populer
    • Secara teknis JAX adalah proyek DeepMind dan memiliki arti penting bagi dorongan AI Google secara keseluruhan, tetapi tampaknya akan memberi manfaat besar bagi seluruh ekosistem dalam jangka panjang
    • Badan tata kelola terpisah akan memberi arahan bagi pengembangan proyek
    • Ini akan memberi struktur yang konkret dan terpisah dari birokrasi Google yang terkenal rumit, sehingga bisa menghindari banyak masalah sekaligus
    • Bukan berarti JAX mutlak memerlukan struktur formal seperti ini, tetapi akan bagus jika ada jaminan bahwa pengembangan JAX akan terus berlangsung lama terlepas dari keputusan manajemen puncak Google
    • Ini jelas akan membantu adopsi di perusahaan dan laboratorium riset besar yang ragu menginvestasikan sumber daya untuk mengintegrasikan alat yang suatu hari nanti bisa saja tidak lagi dipelihara
  • Transisi open source XLA
    • Untuk waktu yang lama, XLA adalah proyek closed source
    • Namun, telah dilakukan upaya untuk menjadikannya open source, dan saat ini OpenXLA menunjukkan performa yang jauh lebih baik daripada build XLA internal
    • Namun dokumentasi tentang bagian dalam XLA masih kurang
    • Sebagian besar sumber daya hanya berupa live talk dan sesekali paper, dan sering kali sudah usang
    • Jika ada roadmap yang bisa diakses publik tentang fitur yang direncanakan, orang akan lebih mudah melacak kemajuan dan terutama berkontribusi pada hal-hal yang menarik
    • Akan baik jika ada mini blog post bergaya Edward Yang yang membedah tiap tahap stack kompiler XLA dan menjelaskan detailnya, sehingga praktisi bisa lebih baik menilai apa yang bisa dan tidak bisa dilakukan XLA
    • Ini memang membutuhkan banyak sumber daya dan mungkin lebih baik disalurkan ke tempat lain, tetapi orang akan lebih percaya pada alat yang mereka pahami, dan diyakini ini memberi efek limpahan positif ke seluruh ekosistem sehingga menguntungkan semua pihak
  • Integrasi ekosistem
    • flax adalah sumber frustrasi di ekosistem JAX
    • API-nya tidak intuitif, sintaksnya ringkas, dan bagi pemula yang pindah dari PyTorch, ini benar-benar seperti neraka
    • Disarankan memakai equinox
    • Ada upaya dari tim pengembang untuk mengatasi kekurangan flax, tetapi pada akhirnya itu buang-buang waktu
    • Jika menginginkan API bergaya equinox, lebih baik langsung pakai equinox
    • Tidak banyak hal yang secara khusus dilakukan flax dengan lebih baik dan itu pun tidak sulit direplikasi dengan equinox
    • Saat ini banyak bagian ekosistem JAX dirancang berpusat pada flax
    • Karena equinox pada dasarnya berinteraksi dengan PyTree, ia interoperabel dengan semua library, meski membutuhkan sedikit eqx.partition dan filter
    • Ingin mengubah status quo. equinox harus mendapat dukungan kelas satu di mana-mana
    • Ini memang pendapat yang kontroversial, tetapi ini adalah contoh klasik sesat pikir sunk cost
    • equinox bekerja lebih baik dengan cara yang seharusnya selalu dimiliki framework JAX
    • Seperti yang dirangkum dalam dokumentasi equinox, jika membandingkan equinox dan flax, equinox lebih baik
    • Bagus jika pengelola ekosistem JAX mengakui popularitas equinox dan menyesuaikan diri, tetapi juga berharap ada dukungan resmi yang lebih besar dari Google dan tim flax
    • Jika ingin mencoba JAX, lebih baik gunakan equinox
  • Sudut-sudut tajam
    • Karena keputusan desain API dan batasan XLA, JAX memiliki beberapa "sudut tajam" yang perlu diwaspadai
    • Hal ini dijelaskan dengan sangat ringkas dalam dokumentasi yang ditulis dengan baik
    • Sebaiknya dibaca setidaknya sekali sebelum menggunakan JAX
    • Seperti biasa, melakukan RTFM akan menghemat banyak waktu dan energi

Kesimpulan

  • Postingan blog ini bertujuan meluruskan mitos yang sering diulang bahwa PyTorch paling cocok untuk workload riset nyata, khususnya di GPU. Itu tidak lagi benar
  • Bahkan sampai berani berpendapat bahwa mem-porting semua kode PyTorch ke JAX akan sangat bermanfaat bagi seluruh bidang ini
    • Paralelisasi otomatis, reproduksibilitas, API fungsional yang bersih, dan sebagainya bukanlah fitur sepele, dan akan sangat membantu banyak codebase riset
  • Jika Anda ingin membuat bidang ini sedikit lebih baik, pertimbangkan untuk menulis ulang codebase Anda ke JAX

8 komentar

 
xguru 2024-08-25
 
hilft 2024-08-21

Saya akan bertahan dengan torch dan onnx.

 
flrngel 2024-08-21

Tulisan yang ditulis mahasiswa S1.. gila

 
cosine20 2024-08-21

PyTorch benar-benar tamat kalau bukan karena Hugging Face wkwk

 
lemonmint 2024-08-19

Hidup JAX! Saya baru mencobanya belakangan ini, dan saya sangat menyukai API NNX.

 
stareta1202 2024-08-19

Masalah terbesar JAX adalah fakta bahwa ini dari Google. Google sangat terkenal suka meninggalkan open source (Tflite, android things, dart, angular, bazel, dll.); bahkan tensorflow pun sejak suatu titik mulai tidak terlalu terbarui dengan baik. Sebaliknya, torch berawal dari Facebook, yang mengelola ekosistem open source besar, dikelola dengan sangat baik, dan sekarang sudah berada di bawah yayasan torch. Kekurangan torch memang jelas ada benarnya, tetapi dalam hal siapa yang akan mengelola open source tersebut secara berkelanjutan, rasanya JAX sudah memulai dengan risiko besar.

 
dalinaum 2024-08-20

Setidaknya Dart tampaknya masih akan baik-baik saja untuk sementara waktu berkat Flutter.

 
ilotoki0804 2024-08-20

Facebook tampaknya tetap setia(?) dan terus berkontribusi pada tech stack yang mereka gunakan sendiri, seperti React dan Django, tetapi Google rasanya sedikit saja sesuatu jadi usang langsung dibuang begitu saja seperti kain lap...