Flash-KMeans rappresenta un punto di svolta nel campo degli algoritmi di clustering esatti su GPU grazie alla sua efficienza nell’uso della memoria. Mentre l’algoritmo k-means è stato principalmente utilizzato come preprocessore offline, moderni pipeline di AI lo integrano all'interno di cicli di addestramento e inferenza, richiedendo un'elaborazione con bassa latenza.
Che cos'è Flash-KMeans?
Flash-KMeans è una libreria open-source scritta utilizzando Triton GPU kernels, distribuita con licenza Apache 2.0. I risultati matematici sono identici a quelli dell’algoritmo Lloyd tradizionale, tuttavia la velocità è raggiunta ottimizzando il flusso di dati. Essa non approssima né modifica i calcoli ma si concentra sull’ottimizzare l’accesso alla memoria.
Può essere installata utilizzando semplicemente il comando:
pip install flash-kmeans
Le Due Bottiglie di Cervice
Flash-KMeans si concentra su due aspetti principali del processo di clustering: l'assegnazione di punti ai centroidi e l'aggiornamento dei centroidi, entrambi limitati da problemi di accesso alla memoria.
Assegnazione dei Punti
Nel tradizionale k-means, ogni punto viene confrontato con ogni centroide, creando una matrice di distanze di dimensioni NxK, che occupa molto spazio in HBM. Leggere e scrivere in questa matrice è costoso. Flash-KMeans introduce FlashAssign, un’architettura che legge e calcola la distanza di piccoli blocchi di dati (tiles) in memoria SRAM, evitando la generazione di una matrice completa. Questo abbassa l’accesso alla memoria da O(NK) a O(Nd + Kd), miglioramento significativo.
- FlashAssign utilizza un approccio simile a FlashAttention.
- Effonde i blocchi punti e centroidi direttamente da memoria HBM in SRAM.
- Calcola la distanza e aggiorna l’assegnazione in tempo reale.
Questo approccio ha dimostrato di migliorare l’assegnazione fino a 21.2 volte in una configurazione test con N=1M, K=8192.
Aggiornamento dei Centroidi
Il tradizionale k-means utilizza operazioni di aggiornamento scattering-style, che causano congestione in caso di hotspot di accesso. Flash-KMeans invece implementa Sort-Inverse Update, in cui i cluster vengono raggruppati in segmenti consecutivi tramite argsort. Questo elimina la competizione per un singolo centrale riducendo i contenziosi tra thread. I benefici riscontrati in termini di larghezza di banda di memoria e di velocità possono essere fino a 6.3 volte meglio rispetto al tradizionale.
- Sort-Inverse Update organizza gli assegnamenti in modo che i dati vengano raccolti per cluster.
- Ogni segmento di dati viene ridotto in SRAM.
- Un unico aggiornamento per segnamento riduce il numero di operazioni atomiche.
Test e Benchmark
I test di riferimento sono stati eseguiti su un sistema GPU NVIDIA H200 con dati in FP16 e dimensione d=128. Sono stati paragonati diversi parametri N, K e B a livello di implementazione su una serie di baseline tra cui fastpytorchkmeans, fastkmeans, cuML e FAISS. Gli aumenti di velocità misurati sono significativi.
- Confronto con cuML (biblioteca industriale): miglioramenti di tempo fino al 33×.
- Confronto con FAISS (biblioteca standard per l’indicizzazione vettoriale): miglioramenti fino a più del 200×.
- FlashAssign ottimizza l’assegnazione con incremento fino a 21.2×.
- Kernel di aggiornamento con Sort-Inverse Update migliora con un fattore fino a 6.3×.
Un caso test su 1B punti (K=32768) completa un'iterazione in 41,4 secondi, mentre la baseline impiega 261,8 secondi.
Applicazione su Grandi Dataset con Fuori Memoria (Out-of-Core)
Una funzionalità chiave di Flash-KMeans è la capacità di gestire grandi dataset con una tecnica detta out-of-core. Il kernel utilizza un’ottimizzazione chunked stream overlap per mascherare i trasferimenti di PCIe durante il calcolo.
- Caching-aware compilation riduce i tempi di test da 175× fino al 0,3%.
- Gestisce 400M punti, K=16384, rispetto a fastkmeans con incremento del 10,5×.
- Flash-KMeans gestisce anche 400M punti e K=16384, completando l’iterazione velocemente.
Utilizzo in Applicazioni Critiche
Flash-KMeans sta redefinendo come si possono utilizzare clustering esatto in applicazioni real time, come:
- Indicizzazione per la ricerca vettoriale: FAISS utilizza k-means per l’indicizzazione. Flash-KMeans permette di ridurre il tempo di reinserimento di dati nuovi.
- Routing attention con Transformer: I tokenizer vengono assegnati a cluster in tempo reale grazie alla velocità di Flash-KMeans.
- Compressione cache KV: Compressione layer-specifico grazie al clustering semantico.
- Attenuazione su KV entries: Processo di clustering per entry in codice.
- Transformer per diffusione video: Permuta i token in base all’accesso semantico.
Usare Flash-KMeans
Flash-KMeans è compatibile con framework come FAISS e sklearn. Ecco un esempio di utilizzo per clustering batch:
import torch
from flashkmeans import batchkmeans_Euclid
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
clusterids, centers, = batchkmeansEuclid(
x, n_clusters=1000, tol=1e-4, verbose=True
)
Alternativamente, può essere utilizzato con una API in stile scikit-learn:
from flash_kmeans import FlashKMeans
km = FlashKMeans(d=128, k=8192, niter=100)
labels = km.fitpredict(largecpu_tensor) # device=None utilizza tutti i GPU visibili
- L'algoritmo si adatta automaticamente in base alla dimensione del problema.
- I percorsi low-dim (d ≤ 512) vengono gestiti separatamente.
- I percorsi di alta dimensione evitano