nanoSparseAttention provides clean, educational implementations of recent Sparse Attention mechanisms for both prefilling and generation stages of LLM inference. The repository prioritizes clarity and understanding over performance, making it ideal for learning and experimentation.
We provide a Jupyter notebook that offers:
- Detailed explanation of Sparse Attention concepts
- Step-by-step implementation walkthrough
- Visualization of attention patterns
- Performance comparisons between different methods
The notebook was prepared for the NeurIPS 2024 Dynamic Sparsity Workshop. Check it out if you want to learn more about dynamic execution, not only in the context of self-attention!
- Pure PyTorch Implementation: All attention mechanisms are implemented in pure PyTorch for maximum clarity and ease of understanding.
- Real-world Testing: Uses the Llama-3.2-1B-Instruct model and the FiscalNote/billsum dataset for practical experiments.
- Comprehensive Tutorial: Includes a detailed Jupyter notebook explaining core concepts and implementations.
- Extensible Design: Easy to add new models, datasets, and attention patterns through a modular architecture.
- Flexible Inference: Supports both prefilling and generation stages, and lets you mix sparse methods across the two (see the sketch after this list).
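As an example of that flexibility, here is a minimal sketch that pairs a sparse prefilling method with a sparse generation method in a single handler. It reuses the classes and helpers shown in the Quick Start below; the parameter values are only illustrative.

from nano_sparse_attn.attention import InferenceHandler, LocalAndSinksAttention, SnapKVAttention
from nano_sparse_attn.utils import load_model_and_tokenizer, load_examples, update_attention, model_forward

# Load the model and a single example input
model, tokenizer = load_model_and_tokenizer()
model_inputs = load_examples(tokenizer, num_examples=1)

# Sparse prefilling (local window + sinks) combined with sparse generation (SnapKV)
handler = InferenceHandler(
    prefill_attention=LocalAndSinksAttention(window_size=256, attention_sinks=16),
    generation_attention=SnapKVAttention(approximation_window=64, token_capacity=256),
)

update_attention(model, handler)
loss = model_forward(model, model_inputs, handler)
print(handler.info())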
Prefilling methods:
- Local Window + Attention Sinks (Xiao et al., 2023; Han et al., 2024); a pure-PyTorch sketch of this mask follows the list
- Vertical-Slash Attention (Jiang et al., 2024)
- Block-Sparse Attention (Jiang et al., 2024)
Generation methods:
- Local Window + Attention Sinks (Xiao et al., 2023; Han et al., 2024)
- SnapKV (Li et al., 2024)
- TOVA (Oren et al., 2023)
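To give a feel for how such patterns look in plain PyTorch, here is a minimal sketch of the boolean mask behind Local Window + Attention Sinks. It is illustrative only and not taken from the repository (the helper name local_and_sinks_mask is made up): each query attends to the first attention_sinks tokens plus a causal window of the most recent window_size tokens.

import torch

def local_and_sinks_mask(seq_len, window_size, attention_sinks):
    # Boolean [seq_len, seq_len] mask: True = attend, False = masked out.
    q = torch.arange(seq_len).unsqueeze(1)  # query positions (rows)
    k = torch.arange(seq_len).unsqueeze(0)  # key positions (columns)
    causal = k <= q                         # never attend to future tokens
    local = (q - k) < window_size           # keep the most recent window_size keys
    sinks = k < attention_sinks             # always keep the first few "sink" tokens
    return causal & (local | sinks)

mask = local_and_sinks_mask(seq_len=16, window_size=4, attention_sinks=2)
print(mask.int())  # 1 = attended, 0 = masked; note the diagonal band plus leftmost columns

Vertical-Slash and Block-Sparse follow the same idea with different mask shapes (selected columns plus diagonals, or selected blocks), while the generation-time methods decide which entries of the KV cache to keep.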
Assuming you want to use Python's venv, it's as easy as:
git clone https://github.com/PiotrNawrot/nano-sparse-attention
cd nano-sparse-attention
python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools wheel psutil
pip install -e ./
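After installation, a quick import check confirms the environment is set up. This is just a suggestion, not one of the repository's scripts; the nano_sparse_attn package name comes from the Quick Start imports below.

# Run inside the activated virtual environment
import torch
import nano_sparse_attn

print("torch:", torch.__version__)
print("nano_sparse_attn:", nano_sparse_attn.__file__)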
The repository provides two main scripts for experimenting with sparse attention mechanisms. The snippets below mirror them: first a sparse-prefilling example, then a sparse-generation example.
from nano_sparse_attn.attention import InferenceHandler, DenseAttention, LocalAndSinksAttention
from nano_sparse_attn.utils import load_model_and_tokenizer, load_examples, update_attention, model_forward
# Load model and prepare inputs
model, tokenizer = load_model_and_tokenizer()
model_inputs = load_examples(tokenizer, num_examples=1)
# Create an inference handler with Local Window + Attention Sinks
handler = InferenceHandler(
    prefill_attention=LocalAndSinksAttention(
        window_size=256,
        attention_sinks=16
    ),
    generation_attention=DenseAttention()
)
# Update model's attention mechanism and run forward pass
update_attention(model, handler)
loss = model_forward(model, model_inputs, handler)
# Get information about the attention mechanism
info = handler.info()
print(f"Loss: {loss}")
print(f"Sparsity: {info['prefill']['sparsity']}")
# Assumes imports from the previous example
from nano_sparse_attn.attention import SnapKVAttention
# Create an inference handler with SnapKV for generation
handler = InferenceHandler(
    prefill_attention=DenseAttention(),
    generation_attention=SnapKVAttention(
        approximation_window=64,
        token_capacity=256
    )
)
# Update model's attention mechanism and run forward pass
update_attention(model, handler)
loss = model_forward(model, model_inputs, handler)
# Get information about the attention mechanism
info = handler.info()
print(f"Loss: {loss}")
print(f"Sparsity: {info['generation']['sparsity']}")
For ready-to-use scripts check out main_prefill.py and main_generate.py. For a detailed walkthrough of the repository and information about extending it to new models, datasets, and attention patterns, refer to this README.
Contributions are welcome! Our goal is to keep this repository up to date with the latest Sparse Attention methods by consistently adding new ones. Feel free to submit a Pull Request if 1) you would like a new method to be added, or 2) [even better] you have an implementation of a new Sparse Attention method!
Piotr Nawrot - Website - piotr@nawrot.org
Edoardo Maria Ponti - Website - eponti@ed.ac.uk