Flash-KMeans rappresenta un punto di svolta nel campo degli algoritmi di clustering esatti su GPU grazie alla sua efficienza nell’uso della memoria. Mentre l’algoritmo k-means è stato principalmente utilizzato come preprocessore offline, moderni pipeline di AI lo integrano all'interno di cicli di addestramento e inferenza, richiedendo un'elaborazione con bassa latenza.

Che cos'è Flash-KMeans?

Flash-KMeans è una libreria open-source scritta utilizzando Triton GPU kernels, distribuita con licenza Apache 2.0. I risultati matematici sono identici a quelli dell’algoritmo Lloyd tradizionale, tuttavia la velocità è raggiunta ottimizzando il flusso di dati. Essa non approssima né modifica i calcoli ma si concentra sull’ottimizzare l’accesso alla memoria.

Può essere installata utilizzando semplicemente il comando:

pip install flash-kmeans

Le Due Bottiglie di Cervice

Flash-KMeans si concentra su due aspetti principali del processo di clustering: l'assegnazione di punti ai centroidi e l'aggiornamento dei centroidi, entrambi limitati da problemi di accesso alla memoria.

Assegnazione dei Punti

Nel tradizionale k-means, ogni punto viene confrontato con ogni centroide, creando una matrice di distanze di dimensioni NxK, che occupa molto spazio in HBM. Leggere e scrivere in questa matrice è costoso. Flash-KMeans introduce FlashAssign, un’architettura che legge e calcola la distanza di piccoli blocchi di dati (tiles) in memoria SRAM, evitando la generazione di una matrice completa. Questo abbassa l’accesso alla memoria da O(NK) a O(Nd + Kd), miglioramento significativo.

    • FlashAssign utilizza un approccio simile a FlashAttention.
    • Effonde i blocchi punti e centroidi direttamente da memoria HBM in SRAM.
    • Calcola la distanza e aggiorna l’assegnazione in tempo reale.

Questo approccio ha dimostrato di migliorare l’assegnazione fino a 21.2 volte in una configurazione test con N=1M, K=8192.

Aggiornamento dei Centroidi

Il tradizionale k-means utilizza operazioni di aggiornamento scattering-style, che causano congestione in caso di hotspot di accesso. Flash-KMeans invece implementa Sort-Inverse Update, in cui i cluster vengono raggruppati in segmenti consecutivi tramite argsort. Questo elimina la competizione per un singolo centrale riducendo i contenziosi tra thread. I benefici riscontrati in termini di larghezza di banda di memoria e di velocità possono essere fino a 6.3 volte meglio rispetto al tradizionale.

    • Sort-Inverse Update organizza gli assegnamenti in modo che i dati vengano raccolti per cluster.
    • Ogni segmento di dati viene ridotto in SRAM.
    • Un unico aggiornamento per segnamento riduce il numero di operazioni atomiche.

Test e Benchmark

I test di riferimento sono stati eseguiti su un sistema GPU NVIDIA H200 con dati in FP16 e dimensione d=128. Sono stati paragonati diversi parametri N, K e B a livello di implementazione su una serie di baseline tra cui fastpytorchkmeans, fastkmeans, cuML e FAISS. Gli aumenti di velocità misurati sono significativi.

    • Confronto con cuML (biblioteca industriale): miglioramenti di tempo fino al 33×.
    • Confronto con FAISS (biblioteca standard per l’indicizzazione vettoriale): miglioramenti fino a più del 200×.
    • FlashAssign ottimizza l’assegnazione con incremento fino a 21.2×.
    • Kernel di aggiornamento con Sort-Inverse Update migliora con un fattore fino a 6.3×.

Un caso test su 1B punti (K=32768) completa un'iterazione in 41,4 secondi, mentre la baseline impiega 261,8 secondi.

Applicazione su Grandi Dataset con Fuori Memoria (Out-of-Core)

Una funzionalità chiave di Flash-KMeans è la capacità di gestire grandi dataset con una tecnica detta out-of-core. Il kernel utilizza un’ottimizzazione chunked stream overlap per mascherare i trasferimenti di PCIe durante il calcolo.

    • Caching-aware compilation riduce i tempi di test da 175× fino al 0,3%.
    • Gestisce 400M punti, K=16384, rispetto a fastkmeans con incremento del 10,5×.
    • Flash-KMeans gestisce anche 400M punti e K=16384, completando l’iterazione velocemente.

Utilizzo in Applicazioni Critiche

Flash-KMeans sta redefinendo come si possono utilizzare clustering esatto in applicazioni real time, come:

    • Indicizzazione per la ricerca vettoriale: FAISS utilizza k-means per l’indicizzazione. Flash-KMeans permette di ridurre il tempo di reinserimento di dati nuovi.
    • Routing attention con Transformer: I tokenizer vengono assegnati a cluster in tempo reale grazie alla velocità di Flash-KMeans.
    • Compressione cache KV: Compressione layer-specifico grazie al clustering semantico.
    • Attenuazione su KV entries: Processo di clustering per entry in codice.
    • Transformer per diffusione video: Permuta i token in base all’accesso semantico.

Usare Flash-KMeans

Flash-KMeans è compatibile con framework come FAISS e sklearn. Ecco un esempio di utilizzo per clustering batch:

import torch

from flashkmeans import batchkmeans_Euclid

x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)

clusterids, centers, = batchkmeansEuclid(

x, n_clusters=1000, tol=1e-4, verbose=True

)

Alternativamente, può essere utilizzato con una API in stile scikit-learn:

from flash_kmeans import FlashKMeans

km = FlashKMeans(d=128, k=8192, niter=100)

labels = km.fitpredict(largecpu_tensor) # device=None utilizza tutti i GPU visibili

    • L'algoritmo si adatta automaticamente in base alla dimensione del problema.
    • I percorsi low-dim (d ≤ 512) vengono gestiti separatamente.
  • I percorsi di alta dimensione evitano