MiniMax ha sviluppato MSA (MiniMax Sparse Attention), un nuovo meccanismo sparsificato per l'attenzione in modelli linguistici di grandi dimensioni. MSA si basa direttamente su GQA (Grouped Query Attention), un approccio che raggruppa testi per gestire l’elaborazione in maniera più efficiente. L’obbiettivo di MiniMax è risolvere un collo di bottiglia chiave: la crescita quadratica del tempo di computazione dell’attenzione softmax con il raffinarsi del contesto lungo.

Che cos’é MSA

MSA divide il processo di attenzione in due fasi principali: una branch di Indice e una branch di Attenzione principale. La branch di Indice decide per ogni query i blocchi chiave-valore su cui concentrarsi. Successivamente, la branch principale applica l’attenzione softmax esatta su solo quei blocchi, riducendo considerevolmente la complessità computazionale.

La selezione avviene a livello di blocchi e non token per token. La dimensione predefinita di un blocco è Bk = 128 token. Ogni query e gruppo GQA conserva k = 16 blocchi. Questo blocco limita il budget per query a kBk = 2,048 token chiave-valore.

Mentre l’attenzione densa GQA scala con una complessità lineare O(N) al contesto totale, MSA invece mantiene una complessità fissa O(kB_k), la quale non aumenta con la lunghezza del contesto. Questo gap cresce significativamente al crescere della lunghezza contestuale.

Funzionamento Delle Due Branch

La branch di Indice introduce solamente due matrici di proiezione in aggiunta alla struttura di base del GQA. Viene definita una testa di query di indirizzo per gruppo GQA e una testa chiave condivisa tra gruppi. Esso calcola una valutazione per i token chiave visibili, e li raggruppa a livello di blocco utilizzando l’operazione max-pooling.

Un operatore Top-k seleziona i blocchi di punteggio più alto per query, garantendo che anche il blocco contenente la query venga incluso. Questo approccio evita che la selezione ignorasse i blocchi immediatamente vicini alla query.

La branch principale raccoglie i token visibili casualmente da tali blocchi selezionati. Applica un’attenzione softmax dot-product su questi token. Ogni testa del query conserva la sua proiezione ma condivide lo stesso insieme di blocchi del gruppo.

Come è Addestrato MSA

Poiché la selezione Top-k non è differenziabile, MiniMax ha adottato una perdita di allineamento KL (Kullback-Leibler) per allenare la branch di Indice. Questa misura confronta la distribuzione della branch di Indice con quella della branch principale, adottando come insegnante la distribuzione aggregata della branch principale.

Per rendere stabile il training sparso sono state adottate tre meccaniche: Gradient Detach, Indexer Warmup, e Local Block Reservation. Gradient Detach aggiunge uno stop-gradient all'input della branch di Indice, permettendo alla perdita KL di focalizzarsi sulle proiezioni degli indici, isolandole dal resto della rete. Indexer Warmup inizializza usando l'attenzione standard prima che la branch di Indice assuma la decisione effettiva.

Due percorsi di addestramento sono supportati: MSA-PT inizia da zero dopo un inizializzatore da 40B token, e MSA-CPT converte un checkpoint densamente addestrato su 2.6T token. Continua l'addestramento per ulteriori 400B token.

Co-Design Del Kernel

MiniMax ha abbinato MSA a un kernel specifico per GPU, per ottimizzare la velocità di calcolo. Il primo meccanismo del kernel impiega una selezione Top-k free da funzioni exp(). Questo kernel evita calcoli di massimo, esponenziale o somma, rendendolo più veloce. Per contesti di 128K token con k = 16, il kernel ha mostrato di essere 5,1 volte più veloce del comando torch.topk di PyTorch.

Il secondo meccanismo è l’attenzione sparsificata a blocchi chiave-valore (KV-outer) con raccolta su query. Con iterazioni ottimizzate su blocchi KV rispetto a query, MiniMax ha raggiunto un miglioramento considerevole in efficienza computazionale.

Il kernel opensource, fmha_sm100, è stato progettato per i dispositivi NVIDIA basati su SM100. Sottoscrive licenza MIT e supporta diverse precisioni: FP8, NVFP4, BF16, ed FP4.

Confronto Con Altri Metodi Sparso

MiniMax ha confrontato MSA con quattro metodi esistenti per modelli sparsi. I principali differenziali di MSA si rivelano in gran parte legati all’organizzazione a blocchi e alla distribuzione delle attenzioni all'interno dei gruppi di GQA.

    • NSA seleziona e comprime blocchi con finestre scorrevoli.
    • InfLLM-V2 permette il passaggio da un modello densamente connesso a uno sparso.
    • MoBA utilizza blocchi molto grandi con chiavi media.
    • DSA adotta un indexer basato su ReLU.

La particolarità distintiva di MSA risiede nel fatto che la gestione degli indici di Top-k è condivisa per gruppo GQA, permettendo al tempo stesso di mantenere la coerenza all’intero set di input.

Qualità Del Modello

MiniMax ha condotto test su benchmark notevoli, dimostrando che i modelli sparso restano competitivi rispetto a versioni dense. Con un budget di 3T token, MSA ha raggiunto risultati eccezionali.

Benchmark

    • MMLU: 66.8 (MSA-PT), 67.2 (base), 67.0 (MSA-CPT)
    • GSM8K: 76.2 (MSA-PT), 77.7 (base), 73.7 (MSA-CPT)
    • HumanEval: 57.9 (MSA-CPT), 61.0 (base), 64.0 (MSA-PT)
    • RULER-8K: 77.2 (MSA-CPT), 79.8 (base), 84.2 (MSA-PT)
    • VideoMME: 39.65 (MSA-CPT), 41.11 (base), 45.48 (MSA-PT)

Applicazioni Pratiche

MSA è ideale in contesti