AlphaFold Berilustrasi
(elanapearl.github.io)- AlphaFold3 berupaya memprediksi hanya dari sekuens kompleks yang berisi protein, asam nukleat, dan molekul kecil, melampaui protein tunggal; karena itu representasi input dan tokenisasinya jauh lebih kompleks dibanding AF2
- Input dibagi menjadi representasi single/pair pada level token, representasi pada level atom, MSA, dan template; asam amino serta nukleotida standar diperlakukan sebagai 1 token, sedangkan residu nonstandar dan molekul lain diperlakukan sebagai 1 token per atom
- Trunk pembelajaran representasi secara berulang memperbaiki representasi single s dan representasi pair z melalui modul template, modul MSA, Pairformer, pair-bias attention, operasi triangle, dan recycling
- Prediksi struktur menggunakan model difusi bersyarat atas koordinat atom, bukan Invariant Point Attention milik AF2, serta menghasilkan pembaruan koordinat semua atom melalui augmentasi rotasi/translasi dan denoising
- Pelatihan menggabungkan distogram, diffusion, dan confidence loss, lalu melalui cross-distillation yang memanfaatkan hasil AF2 dan AF-Multimer, model juga mempelajari kembali representasi unfolded pada area berkepercayaan rendah
Cakupan input AlphaFold3 dan keseluruhan pipeline
- Tujuan AlphaFold3 bukan hanya memprediksi sekuens protein individual seperti AF2, atau hanya menangani kompleks protein seperti AF-Multimer, melainkan memprediksi hanya dari sekuens struktur tempat protein berikatan dengan protein lain, asam nukleat, dan molekul kecil secara opsional
- Makna “token” berbeda bergantung pada jenis input
- Protein: 1 asam amino standar adalah 1 token
- DNA/RNA: 1 nukleotida standar adalah 1 token
- Asam amino/nukleotida nonstandar: 1 atom adalah 1 token
- Molekul lainnya: 1 atom adalah 1 token
- Protein yang terdiri dari 35 asam amino standar sebenarnya bisa memiliki lebih dari 600 atom, tetapi direpresentasikan sebagai 35 token; ligand dengan 35 atom direpresentasikan sebagai 35 token
- Model secara garis besar terdiri dari tiga tahap
- Input Preparation: mengubah sekuens input pengguna serta sekuens/struktur terkait hasil pencarian menjadi tensor numerik
- Representation Learning: memperbarui representasi single dan pair dengan berbagai variasi attention
- Structure Prediction: memprediksi struktur melalui difusi bersyarat
- Kompleks protein terutama disimpan dalam dua representasi
- single representation: merepresentasikan semua token dalam kompleks itu sendiri
- pair representation: merepresentasikan relasi seperti jarak dan potensi interaksi di antara semua pasangan token
- Dimensi channel utama adalah
c_z=128,c_m=64,c_atom=128,c_atompair=16,c_token=768,c_s=384
Persiapan input: proses mengubah sekuens menjadi 6 tensor
- Input yang diberikan pengguna diubah menjadi 6 tensor yang masuk ke trunk model
- s: token-level single representation
- z: token-level pair representation
- q: atom-level single representation
- p: atom-level pair representation
- m: MSA representation
- t: template representation
-
Pencarian MSA dan template
- AF3 mencari sekuens serupa untuk sekuens protein dan RNA, menyusunnya sebagai MSA, dan menyertakan struktur terkait sebagai template
- MSA menyelaraskan sekuens protein serupa yang ditemukan di berbagai spesies, sehingga memberi model pola konservasi pada posisi tertentu dan korelasi perubahan antarposisi yang berbeda
- Struktur protein serupa yang sudah diketahui digunakan untuk memperkirakan struktur protein query, seperti dalam homology modeling
- Pencarian tidak melibatkan pelatihan, dan digunakan metode berbasis HMM
- Beberapa basis data protein/RNA dicari dengan
jackhmmer,HHBlits,nhmmer, dan sekuens serupa di Protein Data Bank dicari denganhmmsearch - Ukuran MSA dibatasi menjadi
N_MSA < 2^14karena kompleksitas komputasi - Pada setiap chain protein, struktur berkualitas tinggi dipilih, lalu maksimal 4 disampling sebagai template
- Dibandingkan AF-Multimer, elemen pencarian yang baru ditambahkan adalah sekuens RNA juga dimasukkan sebagai target pencarian
-
Cara merepresentasikan template
- Dari struktur 3D template, jarak Euklides antara setiap pasangan token dihitung
- Token yang memiliki banyak atom menggunakan “center atom” representatif
- Asam amino: atom
Cα - Nukleotida standar: atom
C1'
- Asam amino: atom
- Nilai jarak didiskretisasi sebagai distogram, bukan nilai kontinu
- 38 bin dari 3,15Å hingga 50,75Å
- 1 bin tambahan untuk jarak yang lebih besar dari itu
- Ke dalam distogram ditambahkan informasi chain, apakah token tersebut resolved dalam crystal structure, dan informasi local distance di dalam masing-masing asam amino
- Template matrix di-masking agar hanya melihat jarak di dalam chain yang sama, dan tidak berupaya mendapatkan informasi inter-chain interaction dari pemilihan template
Representasi Tingkat Atom dan Atom Transformer
-
Conformer referensi dan representasi tingkat atom
- Untuk membuat representasi single tingkat atom q, dihitung conformer referensi untuk setiap asam amino, nukleotida, dan ligand
- Conformer adalah susunan atom 3D dari suatu molekul yang dihasilkan dengan mengambil sampel rotasi di sekitar ikatan tunggal
- Asam amino standar menggunakan conformer berenergi rendah yang bisa diperoleh melalui lookup, sedangkan molekul kecil menghasilkan conformer 3D dengan RDKit’s ETKDGv3
- Dengan menggabungkan posisi relatif conformer, muatan atom, nomor atom, identifier, dan sebagainya, dibuat atom-level single representation c
- c digunakan untuk menginisialisasi atom-level pair representation p, dan mask v digunakan agar hanya memuat jarak antaratom yang dihitung dari conformer referensi
- q dimulai sebagai salinan c, lalu diperbarui di Atom Transformer
-
Peran Atom Transformer
- Atom Transformer adalah modul yang melakukan attention tingkat atom, memperbarui q menggunakan p dan representasi awal c
- c tidak diperbarui, dan digunakan seperti residual connection yang mengarah ke representasi awal
- Struktur dasarnya mirip transformer, mencakup LayerNorm, attention, dan MLP transition, tetapi tiap tahap disesuaikan dengan input tambahan c dan p
-
Adaptive LayerNorm
- Adaptive LayerNorm tidak mempelajari
gammadanbetayang tetap, melainkan menghasilkangammadanbetadari input bantu - Di Atom Transformer, target yang di-rescale adalah q, dan parameter rescale diprediksi dari input bantu c
- Adaptive LayerNorm tidak mempelajari
-
Attention with Pair Bias
- Attention tingkat atom dengan pair bias adalah perluasan dari self-attention
- Query, key, dan value semuanya berasal dari single representation q, tetapi setelah dot product query-key, proyeksi linear dari pair representation p ditambahkan sebagai bias
- Informasi mengalir dari pair representation ke q, tetapi pada tahap ini p tidak diperbarui dengan informasi q
- Gate yang dibuat dengan melewatkan proyeksi tambahan melalui sigmoid dikalikan ke hasil attention, sehingga mengatur informasi apa yang dipertahankan di residual stream
- Karena jumlah atom bisa jauh lebih banyak daripada jumlah token, digunakan Sequence-local atom attention alih-alih full attention
- Local group berisi 32 atom dapat melakukan attend ke 128 atom lain
-
Conditioned Gating dan Transition
- Conditioned Gating menerapkan gate yang dihasilkan dari matriks single tingkat atom awal c ke data
- Conditioned Transition setara dengan MLP pada transformer, dan disebut conditioned karena Adaptive LayerNorm serta Conditional Gating bergantung pada c
- AF3 menggunakan SwiGLU di transition block, bukan ReLU
- Transition berbasis ReLU pada AF2 memiliki struktur up-projection 4 kali, ReLU, lalu down-projection
- SwiGLU pada AF3 menerapkan nonlinier swish ke salah satu dari dua up-projection, lalu mengalikannya dan melakukan down-project
Mengagregasi Representasi Atom menjadi Representasi Token
- Karena tahap pembelajaran representasi berikutnya bekerja di token-level, representasi atom-level diagregasi menjadi representasi token-level
- Setelah memproyeksikan atom-level representation ke dimensi yang lebih besar, diambil rata-rata atom-atom yang termasuk dalam token yang sama
- Agregasi rata-rata ini diterapkan ketika beberapa atom terhubung ke satu token, seperti pada asam amino dan nukleotida standar, sedangkan input dengan 1 token per atom dipertahankan apa adanya
- Pada input single token-level, statistik yang diperoleh dari MSA juga digabungkan
- Tipe asam amino
- Distribusi asam amino MSA pada posisi tersebut
- Deletion mean token tersebut
- Untuk token yang tidak memiliki MSA, seperti atom ligand, nilai-nilai ini menjadi 0
- s_inputs yang dibuat dengan cara ini melewati proyeksi menjadi s_init, lalu diperbarui pada tahap representation learning
- Pair representation z_init adalah tensor 3 dimensi yang menyimpan relasi untuk tiap pasangan token, dan setiap z_i,j adalah vektor berdimensi
c_z=128 - Untuk inisialisasi z_i,j, proyeksi s_i dan s_j, relative positional encoding, serta informasi bond antartoken yang ditentukan pengguna ditambahkan
Pembelajaran Representasi: Template, MSA, Pairformer
- Representation learning adalah trunk yang mencakup sebagian besar komputasi model, dengan tujuan memperbaiki representasi single token-level s dan representasi pair z
- Single sequence representation tidak hanya merujuk pada satu sekuens protein, tetapi pada sequence yang menggabungkan semua atom atau token di dalam struktur
-
Template Module
- Setiap template melewati proyeksi linear, lalu ditambahkan dengan proyeksi linear dari pair representation z
- Matriks gabungan melewati Pairformer Stack
- Hasil dari beberapa template dirata-ratakan, lalu melewati linear layer lagi
- ReLU digunakan pada linear layer terakhir, salah satu dari sedikit tempat di AF3 yang menggunakan ReLU sebagai nonlinier
-
MSA Module
- MSA Module sangat mirip dengan Evoformer pada AF2, dan secara bersamaan memperbaiki MSA representation m serta pair representation z
- Tidak semua row MSA digunakan; setelah dilakukan subsampling, proyeksi single representation ditambahkan ke MSA
- Outer Product Mean adalah operasi yang memasukkan informasi MSA ke pair representation
- Untuk setiap indeks token
i,j, dihitung outer product antara m_s,i dan m_s,j untuk semua evolutionary sequence - Nilai ini dirata-ratakan di seluruh sequence, di-flatten, lalu diproyeksikan dan ditambahkan ke z_i,j
- Ini adalah satu-satunya titik dalam model tempat informasi dibagikan antar evolutionary sequence
- Untuk setiap indeks token
- Row-wise gated self-attention using only pair bias menggunakan pair representation untuk memperbarui MSA
- Alih-alih membuat attention score dengan query dan key, pair representation z diproyeksikan menjadi matriks dan digunakan sebagai attention score antartoken
- Karena diterapkan secara independen pada setiap row MSA, pada tahap ini informasi tidak dibagikan antar evolutionary sequence
- Bagian terakhir MSA module memperbarui kembali pair representation melalui triangle update dan triangle attention
Pairformer dan operasi triangle
- Setelah z diperbarui dengan template dan MSA, template dan MSA tidak digunakan lagi, dan hanya s serta z yang dimasukkan ke Pairformer
- Pairformer menghasilkan s_trunk dan z_trunk akhir melalui pengulangan 48 block
-
Intuisi operasi triangle
- triangle update dan triangle attention adalah struktur yang mencoba mencerminkan intuisi ketaksamaan segitiga ke dalam model
- Meskipun z_i,j pada pair tensor bukan jarak fisik itu sendiri, karena memuat relasi antara token
idanj, tiga relasii-j,j-k, dani-kdiperbarui agar saling konsisten - Ketaksamaan segitiga tidak dipaksakan secara langsung di dalam model, melainkan diinduksi dengan melihat semua triplet
(i,j,k)untuk memperbarui z_i,j - z dapat dilihat seperti directed adjacency matrix, sehingga arah outgoing edge dan incoming edge diproses secara terpisah
-
Triangle Updates
- Pada outgoing update, setiap z_i,j diperbarui menggunakan elemen lain pada row yang sama, z_i,k, dan edge ketiga, z_j,k
- Dalam implementasinya, dibuat tiga projection
a,b,gdari z; element-wise multiplication antara rowidan rowjdijumlahkan terhadapk, lalu gategditerapkan - incoming update adalah bentuk dengan row dan column ditukar, yang memperbarui z_i,j melalui elemen lain pada column yang sama, z_k,j, dan z_k,i
-
Triangle Attention
- triangle attention adalah bentuk yang menambahkan prinsip triangle ke axial attention, yang menerapkan attention independen pada row dan column dari matrix 2D
- Dalam case “starting node”, z_j,k ditambahkan sebagai bias pada perbandingan query-key antara z_i,j dan z_i,k
- Dalam case “ending node”, operasi berjalan berdasarkan column, dan attention score antara z_i,j dan z_k,i diberi bias dengan z_k,j
-
Single Attention with Pair Bias
- Setelah triangle step dan transition block, single representation s diperbarui dengan single attention with pair bias yang menggunakan updated pair representation z
- Karena beroperasi pada token-level, digunakan full attention, bukan block-wise sparse attention yang dipakai pada atom-level
Prediksi struktur: denoising koordinat atom dengan difusi
-
Cara dasar model difusi
- AF3 melakukan prediksi struktur akhir dengan atom-level diffusion
- diffusion model dilatih dengan menambahkan random noise secara bertahap ke data nyata, lalu membuat model memprediksi noise apa yang ditambahkan
- Saat inference, proses dimulai dari random noise penuh; pada setiap step, noise yang diprediksi model dihapus untuk menghasilkan denoised datapoint
- Difusi kondisional menerima current noisy generation, representasi timestep saat ini, dan vektor kondisi sebagai input untuk menghasilkan keluaran yang sesuai kondisi
- Di AF3, target denoising adalah matrix x yang berisi koordinat
x,y,zsemua atom
-
Augmentasi rotasi dan translasi sebagai pengganti IPA AF2
- AF3 tidak menggunakan Invariant Point Attention milik AF2, melainkan merotasi dan mentranslasikan seluruh kompleks yang sedang diprediksi secara acak pada setiap timestep
- Augmentasi ini membuat model mempelajari bahwa rotasi dan translasi apa pun tetap valid sebagai struktur yang sama, dan merupakan pendekatan yang lebih sederhana daripada IPA AF2
- Rotasi diterapkan dengan pusat pada rata-rata semua koordinat atom dari generation saat ini, sedangkan translation disampel dari Gaussian
N(0,1)pada tiap dimensi - Noise kecil juga ditambahkan ke koordinat untuk mendorong generation yang lebih beragam
- Saat inference, beberapa generation dapat diberi skor dengan confidence head, lalu generation dengan skor tertinggi dikembalikan
-
Empat tahap Diffusion Module
- Setiap denoising step menggunakan beberapa conditioning representation
- output trunk s_trunk, z_trunk
- representasi awal s_inputs, c_inputs yang dibuat oleh input embedder
- Proses diffusion terdiri dari empat tahap dengan berpindah-pindah antara ruang token dan atom
-
- menyiapkan token-level conditioning tensor
-
- menyiapkan atom-level conditioning tensor, menerapkan Atom Transformer, lalu mengagregasikannya ke token-level
-
- menerapkan token-level attention
-
- memprediksi noise update per atom dengan atom-level attention
-
- Pada token-level conditioning, z_trunk digabungkan dengan relative positional encoding lalu dilewatkan melalui transition block
- Pada single representation, s_inputs dan s_trunk digabungkan, lalu ditambahkan Fourier embedding sesuai diffusion timestep
- Pada tahap atom-level, c dan p awal diperbarui dengan current token-level representation, dan koordinat saat ini x diskalakan dengan data variance untuk membuat dimensionless coordinate r
- Pada tahap atom-level terakhir, linear layer memetakan q ke
R^3untuk menghasilkan coordinate update r_update bagi semua atom - update diskalakan ulang menjadi x_update dengan mempertimbangkan data variance dan noise schedule, lalu diterapkan ke koordinat saat ini x_l
- Setiap denoising step menggunakan beberapa conditioning representation
Fungsi loss dan confidence head
- Loss keseluruhan adalah jumlah berbobot dari tiga komponen
L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence
-
L_distogram
- L_distogram mengevaluasi akurasi distogram yang diprediksi pada level token
- Saat membuat koordinat token dari koordinat atom, digunakan koordinat atom pusat dari tiap token
- Jarak distogram diperlakukan sebagai nilai kategorikal, dan distogram prediksi dibandingkan dengan distogram aktual menggunakan cross entropy
-
L_diffusion
- L_diffusion adalah jumlah berbobot dari beberapa komponen yang menargetkan posisi atom
- L_MSE menghitung mean squared error antarposisi untuk semua atom, bukan hanya atom pusat, dan atom DNA, RNA, serta ligand diberi bobot lebih tinggi
- L_bond adalah komponen MSE tambahan untuk meningkatkan akurasi panjang ikatan pada pasangan atom yang termasuk dalam ikatan protein-ligand
- Pada tahap training awal,
α_bond=0, sehingga komponen ini diperkenalkan belakangan - L_smooth_LDDT adalah loss yang membuat akurasi jarak lokal menjadi halus dan dapat didiferensiasikan
- Menggunakan empat threshold: 4Å, 2Å, 1Å, dan 0.5Å
- Pasangan atom nucleotide diabaikan jika lebih jauh dari 30Å
- Pasangan atom protein atau ligand diabaikan jika lebih jauh dari 15Å
-
L_confidence
- L_confidence tidak secara langsung meningkatkan akurasi struktur, melainkan melatih model agar memperkirakan akurasi prediksinya sendiri
- Terdiri dari loss yang berkorespondensi dengan empat metrik confidence
- pLDDT: akurasi jarak lokal untuk atom-atom yang berdekatan
- PAE: predicted alignment error untuk pasangan token
- PDE: predicted distance error di antara pasangan token
- experimentally resolved prediction: memprediksi apakah tiap atom ter-resolve dalam struktur eksperimental
- Meski struktur prediksi tidak akurat sehingga PAE tinggi, jika model juga memprediksi PAE tinggi, loss PAE tersebut dapat menjadi rendah
- Confidence prediction dihasilkan pada tahap antara dalam diffusion
- Gradient dari confidence loss hanya memperbarui confidence prediction head, dan tidak memengaruhi bagian model lainnya
Teknik pembelajaran tambahan dan efisiensi
-
Recycling
- AF3 menggunakan weight recycling seperti AF2
- Alih-alih membuat model lebih dalam, weight yang sama digunakan kembali beberapa kali untuk menyempurnakan representation secara bertahap
- Diffusion juga secara inheren memiliki recycling, karena pada inference ia menggunakan informasi timestep dan menggunakan kembali weight yang sama pada setiap timestep
-
Cross-distillation
- AF3 menggunakan synthetic training data yang dibuat oleh dirinya sendiri, serta synthetic data yang dibuat oleh AF2 dan AF-Multimer
- Setelah beralih ke generation berbasis diffusion, ada masalah hilangnya bentuk “spaghetti” yang pada AF2 membantu membedakan secara visual area dengan confidence rendah atau disorder
- Dengan memasukkan generation dari AF2 dan AF-Multimer ke dalam training data AF3, AF3 mempelajari cara AF2 mengeluarkan unfolded region pada area yang tidak diyakininya
- Dalam dataset distillation, asam nukleat dan molekul kecil yang tidak dapat ditangani oleh AF2 dan AF-Multimer dihapus
- Setelah model sebelumnya membuat struktur prediksi lalu melakukan alignment dengan struktur asli, molekul yang sebelumnya dihapus ditambahkan kembali
- Jika molekul yang ditambahkan kembali menimbulkan atom clash, seluruh struktur dikecualikan agar model tidak belajar mengizinkan clash
-
Cropping dan tahap training
- Model itu sendiri tidak memiliki batasan eksplisit atas panjang sekuens input, tetapi beberapa operasi meningkat sebesar
N_tokens^3, sehingga kebutuhan memory dan compute membesar - Untuk efisiensi, protein di-random crop
- Karena interaksi antar beberapa chain harus dimodelkan, crop harus mencakup chain-chain tersebut bersama-sama
- Tiga metode cropping digunakan
- contiguous cropping: memilih sekuens asam amino yang berurutan dari tiap chain
- spatial cropping: memilih asam amino berdasarkan jarak ke atom acuan
- spatial interface cropping: memilih berdasarkan jarak ke atom pada binding interface
- Model yang dilatih dengan random crop 384 juga dapat diterapkan pada sekuens yang lebih panjang, tetapi untuk meningkatkan kemampuan menangani sekuens yang lebih panjang, fine-tuning diulang dengan sequence length yang lebih besar
- Model itu sendiri tidak memiliki batasan eksplisit atas panjang sekuens input, tetapi beberapa operasi meningkat sebesar
-
Clashing dan batch size
- Loss AF3 tidak mencakup clash penalty untuk atom yang saling overlap
- Diffusion-based structure module secara teoretis dapat memprediksi dua atom berada di posisi yang sama, tetapi setelah training masalah tersebut kecil
- Clashing penalty digunakan dalam ranking struktur yang dihasilkan
- Diffusion process terlihat kompleks, tetapi biaya komputasinya lebih rendah daripada trunk
- Untuk efisiensi training, batch size diperbesar setelah trunk
- Setiap input structure melewati embedding dan trunk satu kali, lalu 48 structure independen yang telah diberi data augmentation dilatih secara paralel
Desain AF3 dari Perspektif ML
-
Struktur yang mirip dengan Retrieval-Augmented Generation
- Pencarian MSA dan template pada AF3 memiliki karakter yang mirip dengan RAG pada model bahasa
- Di bidang AlphaFold, cara menggunakan template struktur sudah sejak jauh sebelum istilah RAG dikenal sebagai homology modeling
- AF3 mengurangi porsi pemrosesan MSA dibanding AF2, tetapi tetap menyertakan MSA dan template
- Beberapa model prediksi protein seperti ESMFold menghilangkan retrieval dan menggunakan fully parametric inference
-
Pair-Bias Attention
- Pair-Bias Attention, yang merupakan komponen utama AF2, digunakan lebih luas di AF3
- Query, key, dan value berasal dari source yang sama, tetapi pada attention map ditambahkan bias term dari source lain
- Ini adalah cara berbagi informasi yang lebih ringan dibanding full cross-attention
- Karena pair representation secara alami mirip dengan attention map, struktur ini mungkin cocok untuk pemodelan protein
-
Pengurangan self-supervised training
- Model keluarga ESM menunjukkan keunggulan dalam pendekatan menggantikan embedding MSA dengan self-supervised pre-training
- Di AF2 ada task tambahan untuk memprediksi masked token pada MSA, tetapi ini dihapus di AF3
- AF3 mengurangi compute untuk pemrosesan MSA dan tidak menggunakan self-supervised language modeling pre-training untuk MSA
- Kemungkinan alasannya: massive pre-training tidak efisien dari sisi penggunaan compute, modul MSA kecil lebih baik daripada pre-trained embedding, atau kombinasi struktur hybrid atom-token yang mencampur asam amino, DNA/RNA, dan ligand tidak cocok dengan pre-trained embedding
-
Campuran Classification dan Regression
- Seperti AF2, AF3 menggunakan MSE bersama binned classification loss
- Ciri classification loss adalah bahwa meski hanya salah satu distogram bin, tidak ada credit yang diberikan, sama seperti ketika prediksinya meleset jauh
- Dasar pilihan desain ini tidak jelas, tetapi mungkin gradient-nya lebih stabil dibanding beberapa MSE loss
-
Elemen yang mirip dengan recurrent architecture
- AF3 memiliki banyak elemen yang lebih mengingatkan pada recurrent network daripada transformer umum
- Gating mengendalikan aliran informasi dalam residual stream, mirip dengan gate pada LSTM atau GRU
- Recycling dan diffusion menerapkan weight yang sama berulang kali untuk memperbaiki prediksi secara bertahap
- Mirip dengan adaptive compute time, pembaruan berulang berkaitan dengan struktur yang dapat menerapkan lebih banyak pemrosesan pada input yang sulit
- Dalam ablation AF2, pentingnya recycling terlihat, tetapi tidak banyak pembahasan tentang pentingnya gating
Belum ada komentar.