diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index 31103ff86..24dcb28f6 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -43,9 +43,11 @@ from lm_eval.api.model import T from lm_eval.models.huggingface import HFLM from quantization_utils import quantize_model +from sparse_attention_utils import sparsify_model import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T: @@ -60,9 +62,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) + # Sparse attention arguments + sparse_cfg = arg_dict.pop("sparse_cfg", None) + additional_config = {} if additional_config is None else additional_config additional_config = {k: v for k, v in additional_config.items() if v is not None} + # Force eager attention if sparse attention is requested + if sparse_cfg: + additional_config["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." + ) + # Enable automatic save/load of modelopt state huggingface checkpointing mto.enable_huggingface_checkpointing() @@ -91,6 +104,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | auto_quantize_checkpoint=auto_quantize_checkpoint, ) + if sparse_cfg: + if is_attn_sparsified(model_obj.model): + warnings.warn("Skipping sparse attention: model already has sparse attention applied.") + else: + sparsify_model( + model=model_obj, + sparse_cfg=sparse_cfg, + ) + return model_obj @@ -152,6 +174,11 @@ def setup_parser_with_modelopt_args(): action="store_true", help="Compress the model after quantization", ) + parser.add_argument( + "--sparse_cfg", + type=str, + help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)", + ) return parser @@ -177,6 +204,7 @@ def setup_parser_with_modelopt_args(): "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, + "sparse_cfg": args.sparse_cfg, } ) diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index ca244052b..b6bccd3a7 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -48,6 +48,7 @@ from fire import Fire from modeling import EvalModel, select_model from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model +from sparse_attention_utils import sparsify_model from tqdm import tqdm try: @@ -56,6 +57,7 @@ LLM = None # type: ignore[misc] import modelopt.torch.opt as mto from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -230,6 +232,7 @@ def main( auto_quantize_method: str = "gradient", auto_quantize_score_size: int = 128, auto_quantize_checkpoint: str | None = None, + sparse_cfg: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -266,6 +269,14 @@ def main( max_batch_size=1, ) else: + # Force eager attention if sparse attention is requested + if sparse_cfg: + kwargs["attn_implementation"] = "eager" + warnings.warn( + "Sparse attention requires attn_implementation='eager'. " + "Forcing eager attention implementation." 
+ ) + model = select_model( max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs ) @@ -289,6 +300,20 @@ def main( auto_quantize_checkpoint=auto_quantize_checkpoint, ) + # Apply sparse attention if requested + if sparse_cfg: + model.load() + + if is_attn_sparsified(model.model): + warnings.warn( + "Skipping sparse attention: model already has sparse attention applied." + ) + else: + sparsify_model( + model=model, + sparse_cfg=sparse_cfg, + ) + for subject in tqdm(subjects): dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[ :ntrain diff --git a/examples/llm_eval/modeling.py b/examples/llm_eval/modeling.py index 747b95d5b..d06d05560 100644 --- a/examples/llm_eval/modeling.py +++ b/examples/llm_eval/modeling.py @@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel): lora_path: str = "" device: str = "cuda" load_8bit: bool = False + attn_implementation: str | None = None def load(self): if self.model is None: @@ -188,6 +189,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args) print_gpu_utilization() if self.lora_path: @@ -241,6 +244,8 @@ def load(self): if self.load_8bit: args.update(device_map="auto", load_in_8bit=True) args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto") + if self.attn_implementation: + args["attn_implementation"] = self.attn_implementation self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, **args ) diff --git a/examples/llm_eval/sparse_attention_utils.py b/examples/llm_eval/sparse_attention_utils.py new file mode 100644 index 000000000..dc7a1b14e --- /dev/null +++ b/examples/llm_eval/sparse_attention_utils.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Utilities for sparse attention integration with llm_eval.""" + +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# Custom sparse attention configurations +CUSTOM_SPARSE_CONFIG = { + "SPARSE_CONSERVATIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-4, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, + "SPARSE_AGGRESSIVE": { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 5e-3, "decode": 5e-4}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + }, + "default": {"enable": False}, + }, + }, +} + + +def _extract_model(model_obj): + """Extract actual model from wrapper (HFLM or EvalModel).""" + if hasattr(model_obj, "gpt2"): + return model_obj.gpt2 + elif hasattr(model_obj, "model"): + return model_obj.model + else: + return model_obj + + +def sparsify_model( + model, + sparse_cfg: str, + backend=None, +): + """Apply sparse attention to model with optional RULER calibration. + + Args: + model: Model wrapper (HFLM or EvalModel) or raw model + sparse_cfg: Sparse attention config name or dict + backend: Backend to use (optional, overrides config backend) + + Returns: + The model with sparse attention applied + + Note: + Calibration is automatically triggered if the config contains a 'calibration' field. + The calibration will auto-generate RULER dataset from the model's tokenizer. + """ + # Extract actual model + net = _extract_model(model) + + # Resolve config + if isinstance(sparse_cfg, str): + # Try custom configs first + mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg) + if mtsa_cfg is None: + # Try predefined configs + mtsa_cfg = getattr(mtsa, sparse_cfg, None) + if mtsa_cfg is None: + raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}") + else: + mtsa_cfg = sparse_cfg + + # Override backend if specified + if backend: + if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg: + modified_sparse_cfg = {} + for pattern, cfg in mtsa_cfg["sparse_cfg"].items(): + modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg + if isinstance(modified_cfg, dict): + modified_cfg["backend"] = backend + modified_sparse_cfg[pattern] = modified_cfg + mtsa_cfg = {"sparse_cfg": modified_sparse_cfg} + + # Apply sparsification + print(f"\nApplying sparse attention with config: {sparse_cfg}") + mtsa.sparsify(net, mtsa_cfg) + print("Sparse attention applied successfully!") + + return model diff --git a/examples/llm_sparsity/attention_sparsity/README.md b/examples/llm_sparsity/attention_sparsity/README.md new file mode 100644 index 000000000..204fe8b83 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/README.md @@ -0,0 +1,172 @@ +# Attention Sparsity for HuggingFace Models + +In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. 
+ +## Getting Started + +### Quick Example + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +# Load your model +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, +) + +# Apply sparse attention +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +> [!Note] +> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 or SDPA would bypass the softmax patching needed for stats collection. + +## Configuration Options + +Two pre-defined configurations are available: + +### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT) + +Uses a fixed threshold value. Simple but may not be optimal for all sequence lengths. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT) +``` + +### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB) + +Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use. + +```python +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB + +model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB) +``` + +## Prerequisites + +### Install Requirements + +```bash +pip install -r requirements.txt +``` + +### Download RULER Calibration Data (Required for Calibration) + +If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first: + +```bash +bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh +``` + +This downloads the Paul Graham essays dataset used for generating calibration samples. + +## Run Sparse Attention on HuggingFace Models + +### Basic Usage (Without Calibration) + +Apply sparse attention with a fixed threshold: + +```bash +python hf_sa.py \ + --pyt_ckpt_path meta-llama/Llama-2-7b-hf \ + --sparse_attn skip_softmax \ + --verify_output +``` + +### With RULER Calibration + +Apply sparse attention with calibrated thresholds for optimal sparsity: + +```bash +python hf_sa.py \ + --pyt_ckpt_path meta-llama/Llama-2-7b-hf \ + --sparse_attn skip_softmax_calib \ + --verify_output +``` + +The calibration process: + +1. Generates RULER calibration samples +2. Collects attention statistics during forward passes +3. Determines optimal threshold scale factor for target sparsity ratio + +### Command Line Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--pyt_ckpt_path` | Required | HuggingFace model path or name | +| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` | +| `--backend` | `pytorch` | Backend: `pytorch` or `triton` | +| `--seq_len` | `2048` | Maximum sequence length for input prompts | +| `--verify_output` | `False` | Compare baseline vs sparse attention outputs | +| `--export_dir` | `None` | Directory to export the sparsified model | + +## Verify Outputs + +The `--verify_output` flag compares outputs between baseline (sparse attention disabled) and sparse attention enabled modes: + +```bash +python hf_sa.py \ + --pyt_ckpt_path meta-llama/Llama-2-7b-hf \ + --sparse_attn skip_softmax_calib \ + --verify_output +``` + +This will: + +1. Load a test sample from the NarrativeQA dataset +2. Generate text with sparse attention disabled (baseline) +3. 
Generate text with sparse attention enabled +4. Compare and display both outputs + +## Export Model + +Export the sparsified model to a HuggingFace checkpoint: + +```bash +python hf_sa.py \ + --pyt_ckpt_path meta-llama/Llama-2-7b-hf \ + --sparse_attn skip_softmax_calib \ + --export_dir ./exported_sparse_model +``` + +The exported model can be loaded and used with standard HuggingFace APIs. + +## Custom Configuration + +You can create custom sparse attention configurations: + +```python +custom_config = { + "sparse_cfg": { + "calibration": { # Optional: omit for fixed threshold + "target_sparse_ratio": 0.5, # Target 50% sparsity + "samples": 128, # Number of calibration samples + "max_seqlen": 8192, # Maximum sequence length + }, + "*attn*": { # Pattern to match attention modules + "method": "flash_skip_softmax", + "threshold": 1e-4, # Fixed threshold (ignored if calibration is used) + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + +model = mtsa.sparsify(model, config=custom_config) +``` + +## References + +- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) +- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py new file mode 100644 index 000000000..03bcb75d4 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example script for applying sparse attention to HuggingFace models.""" + +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +# Enable HuggingFace checkpointing support +mto.enable_huggingface_checkpointing() + +# Sparse attention configuration choices +SPARSE_ATTN_CFG_CHOICES = { + "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, +} + + +def get_narrativeqa_samples(num_samples=3): + """Load samples from NarrativeQA dataset for testing. 
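+
+    Prompts combine the full document context with its question and are returned
+    untruncated; callers truncate them to the desired sequence length.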
+ + Args: + num_samples: Number of samples to generate + """ + # Load NarrativeQA dataset + dataset = load_dataset("narrativeqa", split="test", streaming=True) + + samples = [] + for i, item in enumerate(dataset): + if i >= num_samples: + break + + # Combine document context and question + context = item.get("document", {}).get("text", "") + question = item.get("question", {}).get("text", "") + + if context and question: + # Use the full context as-is + prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" + samples.append(prompt) + + if not samples: + raise ValueError("Could not load NarrativeQA samples") + + print(f"Loaded {len(samples)} NarrativeQA samples") + return samples + + +def truncate_text(text: str, tokenizer, max_length: int): + """Truncate text from the middle to preserve beginning and end. + + Args: + text: Input text to truncate + tokenizer: Tokenizer to use for encoding + max_length: Maximum number of tokens + + Returns: + Truncated text that fits within max_length tokens + """ + # First tokenize to see if truncation is needed + tokens = tokenizer.encode(text, add_special_tokens=True) + + if len(tokens) <= max_length: + return text + + # Need to truncate - preserve beginning and end + # Reserve some tokens for special tokens + available_tokens = max_length - 2 # Account for special tokens + + # Split tokens roughly in half for beginning and end + begin_tokens = available_tokens // 2 + end_tokens = available_tokens - begin_tokens + + # Decode beginning and end parts + begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True) + end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True) + + # Combine with ellipsis marker + return begin_text + " [...] " + end_text + + +def verify_outputs(model, tokenizer, args): + """Compare outputs between baseline and sparse attention models.""" + # Update seq_len to match calibration max_seqlen if calibration was used + base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) + if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: + calib_max_seqlen = base_config["calibration"]["max_seqlen"] + if args.seq_len != calib_max_seqlen: + print( + f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " + f"to match calibration config" + ) + args.seq_len = calib_max_seqlen + + # Load and prepare a single test prompt + print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + prompts = get_narrativeqa_samples(num_samples=1) + prompt = prompts[0] + + # Prepare inputs + truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) + display_prompt = ( + truncated_prompt[:150] + "..." 
if len(truncated_prompt) > 150 else truncated_prompt + ) + + inputs = tokenizer( + truncated_prompt, + return_tensors="pt", + max_length=args.seq_len, + truncation=True, + padding=False, + ) + if torch.cuda.is_available(): + inputs = {k: v.cuda() for k, v in inputs.items()} + + print("\n" + "=" * 60) + print("BASELINE vs SPARSE ATTENTION COMPARISON") + print("=" * 60) + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {inputs['input_ids'].shape[1]}") + + # Helper function to generate text + def generate_text(model, inputs, args, tokenizer): + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature if args.do_sample else 1.0, + pad_token_id=tokenizer.pad_token_id, + ) + input_length = inputs["input_ids"].shape[1] + generated_ids = outputs[0][input_length:] + return tokenizer.decode(generated_ids, skip_special_tokens=True) + + # Find all sparse attention modules + sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + # Generate baseline by temporarily disabling sparse attention + print("\n" + "-" * 60) + print("Generating baseline (sparse attention disabled)...") + for module in sparse_modules: + module.disable() + baseline_text = generate_text(model, inputs, args, tokenizer) + + # Generate with sparse attention enabled + print("\nGenerating with sparse attention (calibrated thresholds)...") + for module in sparse_modules: + module.enable() + sparse_text = generate_text(model, inputs, args, tokenizer) + + # Display comparison + print("\n" + "-" * 60) + print("RESULTS:") + baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text + sparse_display = sparse_text[:300] + "..." 
if len(sparse_text) > 300 else sparse_text + + print(f"\nBaseline: {baseline_display}") + print(f"With Sparse: {sparse_display}") + + if baseline_text == sparse_text: + print("\nOutputs are identical") + else: + print("\nOutputs differ") + + +def main(args): + """Main function to run the selected mode.""" + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + launch_memory_monitor() + + print(f"Loading model: {args.pyt_ckpt_path}") + + # Load model and tokenizer + # Note: attn_implementation="eager" is required for calibration to work properly + # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) + model = AutoModelForCausalLM.from_pretrained( + args.pyt_ckpt_path, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + ) + tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) + + # Set pad token if not set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Move model to GPU if available + if torch.cuda.is_available(): + model = model.cuda() + print("Model moved to CUDA") + + # Apply sparse attention with optional calibration + print(f"\nApplying sparse attention: {args.sparse_attn}") + sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + model = mtsa.sparsify(model, config=sparse_config) + + print("Sparse attention applied successfully!") + + # Verify outputs if requested (compares baseline vs calibrated sparse model) + if args.verify_output: + verify_outputs(model, tokenizer, args) + + # Export if requested + if args.export_dir: + print(f"\nExporting model to: {args.export_dir}") + export_dir = Path(args.export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + with torch.inference_mode(): + export_hf_checkpoint(model, export_dir=export_dir) + + tokenizer.save_pretrained(export_dir) + print(f"Model exported successfully to: {export_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + # Model arguments + parser.add_argument( + "--pyt_ckpt_path", + type=str, + required=True, + help="Specify where the PyTorch checkpoint path is", + ) + parser.add_argument( + "--sparse_attn", + type=str, + default="skip_softmax", + choices=list(SPARSE_ATTN_CFG_CHOICES.keys()), + help="Sparse attention configuration to apply.", + ) + parser.add_argument( + "--backend", + type=str, + default="pytorch", + choices=["pytorch", "triton"], + help="Backend to use for sparse attention computation (default: pytorch)", + ) + + # Sequence length arguments + parser.add_argument( + "--seq_len", + type=int, + default=2048, + help="Maximum sequence length for input prompts (will be truncated if longer)", + ) + parser.add_argument( + "--num_samples", + type=int, + default=3, + help="Number of samples to use from NarrativeQA dataset", + ) + + # Generation arguments + parser.add_argument( + "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate" + ) + parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") + + # Operation arguments + parser.add_argument( + "--verify_output", + action="store_true", + help="Verify that sparse attention outputs match baseline", + ) + parser.add_argument( + "--export_dir", + type=str, + default=None, + help="Directory to export the model with sparse attention applied", + ) + + args = parser.parse_args() + 
main(args) diff --git a/examples/llm_sparsity/attention_sparsity/requirements.txt b/examples/llm_sparsity/attention_sparsity/requirements.txt new file mode 100644 index 000000000..a3e0dfa17 --- /dev/null +++ b/examples/llm_sparsity/attention_sparsity/requirements.txt @@ -0,0 +1,2 @@ +nltk +wonderwords diff --git a/examples/llm_sparsity/.gitignore b/examples/llm_sparsity/weight_sparsity/.gitignore similarity index 100% rename from examples/llm_sparsity/.gitignore rename to examples/llm_sparsity/weight_sparsity/.gitignore diff --git a/examples/llm_sparsity/README.md b/examples/llm_sparsity/weight_sparsity/README.md similarity index 100% rename from examples/llm_sparsity/README.md rename to examples/llm_sparsity/weight_sparsity/README.md diff --git a/examples/llm_sparsity/data_prep.py b/examples/llm_sparsity/weight_sparsity/data_prep.py similarity index 100% rename from examples/llm_sparsity/data_prep.py rename to examples/llm_sparsity/weight_sparsity/data_prep.py diff --git a/examples/llm_sparsity/eval.py b/examples/llm_sparsity/weight_sparsity/eval.py similarity index 100% rename from examples/llm_sparsity/eval.py rename to examples/llm_sparsity/weight_sparsity/eval.py diff --git a/examples/llm_sparsity/export_trtllm_ckpt.py b/examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py similarity index 100% rename from examples/llm_sparsity/export_trtllm_ckpt.py rename to examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py diff --git a/examples/llm_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py similarity index 95% rename from examples/llm_sparsity/finetune.py rename to examples/llm_sparsity/weight_sparsity/finetune.py index 3cfc1073f..869068dbd 100644 --- a/examples/llm_sparsity/finetune.py +++ b/examples/llm_sparsity/weight_sparsity/finetune.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li diff --git a/examples/llm_sparsity/hf_pts.py b/examples/llm_sparsity/weight_sparsity/hf_pts.py similarity index 100% rename from examples/llm_sparsity/hf_pts.py rename to examples/llm_sparsity/weight_sparsity/hf_pts.py diff --git a/examples/llm_sparsity/launch_finetune.sh b/examples/llm_sparsity/weight_sparsity/launch_finetune.sh similarity index 100% rename from examples/llm_sparsity/launch_finetune.sh rename to examples/llm_sparsity/weight_sparsity/launch_finetune.sh diff --git a/examples/llm_sparsity/requirements.txt b/examples/llm_sparsity/weight_sparsity/requirements.txt similarity index 100% rename from examples/llm_sparsity/requirements.txt rename to examples/llm_sparsity/weight_sparsity/requirements.txt diff --git a/examples/llm_sparsity/utils.py b/examples/llm_sparsity/weight_sparsity/utils.py similarity index 100% rename from examples/llm_sparsity/utils.py rename to examples/llm_sparsity/weight_sparsity/utils.py diff --git a/modelopt/torch/sparsity/attention_sparsity/__init__.py b/modelopt/torch/sparsity/attention_sparsity/__init__.py new file mode 100644 index 000000000..150f93a3a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensible sparse attention optimization for transformer models.""" + +# Initialize mode +from . import mode + +# Add methods to namespace +from .config import * +from .conversion import * +from .model_sparsify import * diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py new file mode 100644 index 000000000..f6e66ae00 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration functions for sparse attention.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +from transformers import AutoTokenizer + +from ..config import CalibrationConfig +from ..conversion import print_sparse_attention_summary +from ..sparse_attention import SparseAttentionModule +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + + +def _extract_tokenizer_from_model(model: nn.Module) -> str: + """Extract tokenizer name/path from model config. + + Args: + model: Model to extract tokenizer from + + Returns: + Tokenizer name or path + + Raises: + ValueError: If tokenizer path cannot be determined from model + """ + # Extract tokenizer path from model config + tokenizer_path = getattr(getattr(model, "config", None), "_name_or_path", None) + + if not tokenizer_path: + raise ValueError("Could not load tokenizer from model.") + + return tokenizer_path + + +def _extract_calibration_config(config: dict[str, Any]) -> CalibrationConfig | None: + """Extract and validate calibration config from sparse_cfg. + + Args: + config: Sparse attention configuration dict + + Returns: + Validated CalibrationConfig instance, or None if calibration is not configured + + Raises: + ValueError: If calibration config has invalid type or contains invalid values + """ + sparse_cfg = config.get("sparse_cfg", {}) + + # Calibration is optional + if "calibration" not in sparse_cfg: + return None + + calib_dict = sparse_cfg["calibration"] + + # Validate calibration is a dict + if not isinstance(calib_dict, dict): + raise ValueError(f"Calibration config must be a dict, got {type(calib_dict).__name__}. ") + + # Create and validate CalibrationConfig + return CalibrationConfig(**calib_dict) + + +def create_calibration_forward_loop( + calibration_data: list[dict[str, Any]], + tokenizer_name_or_path: str, + batch_size: int = 1, +) -> Callable: + """Create forward loop for calibration. 
+ + Args: + calibration_data: List of samples with 'input' and 'length' fields + tokenizer_name_or_path: HuggingFace tokenizer path + batch_size: Batch size (currently unused, always 1) + + Returns: + Forward loop function that takes model as argument + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + def forward_loop(model: nn.Module) -> None: + device = next(model.parameters()).device + + for sample in calibration_data: + inputs = tokenizer( + sample["input"], return_tensors="pt", truncation=True, max_length=sample["length"] + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + model(**inputs) + + return forward_loop + + +def calibrate_sparse_attention( + model: nn.Module, + config: dict[str, Any], + forward_loop: Callable | None = None, +) -> dict[str, Any]: + """Calibrate sparse attention parameters for optimal sparsity. + + Args: + model: Model with sparse attention modules + config: Sparse attention configuration dict + forward_loop: Callable that forwards calibration data through model. + If None, auto-generates RULER dataset. + + Returns: + Dictionary with calibration results + """ + # Extract and validate calibration config + calib_config = _extract_calibration_config(config) + + # Skip calibration if not configured + if calib_config is None: + return {} + + # Generate forward_loop if not provided + if not forward_loop: + tokenizer = _extract_tokenizer_from_model(model) + builder = RulerDatasetBuilder( + samples=calib_config.samples, + max_seqlen=calib_config.max_seqlen, + tokenizer_name_or_path=tokenizer, + num_length_bins=calib_config.num_length_bins, + max_length_filter=int(calib_config.max_seqlen * 1.5), + ) + calibration_data = builder.build_calibration_dataset() + print(f"Generated {len(calibration_data)} calibration samples") + forward_loop = create_calibration_forward_loop(calibration_data, tokenizer) + + # Get sparse attention modules + sparse_modules = [ + (name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule) + ] + + if not sparse_modules: + print("No sparse attention modules found for calibration") + return {} + + print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") + + # Run calibration + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=calib_config.target_sparse_ratio, + threshold_trials=calib_config.threshold_trials, + ) + calibration_result = calibrator.calibrate(model, forward_loop) + + # Print calibration statistics (regardless of success/failure for debugging) + print("\nCalibration complete!") + print_sparse_attention_summary(model) + + if "scale_factor" not in calibration_result: + warnings.warn("Calibration did not produce valid results") + return {} + + # Apply calibrated scale factor to all modules + scale_factor = calibration_result["scale_factor"] + print(f"\nApplying calibrated scale factor={scale_factor:.6f} to {len(sparse_modules)} modules") + + for module_name, module in sparse_modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + return {"calibration_results": {name: calibration_result for name, _ in sparse_modules}} diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py new file mode 100644 index 000000000..2914651f7 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -0,0 +1,307 
@@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Calibration framework for sparse attention methods.""" + +import warnings +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from ..sparse_attention import SparseAttentionModule +from ..stats_manager import SparseAttentionStatsManager + + +class DynamicThresholdCalibrator: + """Dynamic threshold calibrator using length-based linear relationship. + + Implements calibration algorithm: + 1. Find hyperparameter 'a' where threshold λ = a / context_length + 2. Use dataset with different lengths and test multiple thresholds + 3. For each sample, find optimal threshold closest to target sparsity + 4. Use linear regression to fit: threshold = a * (1/length) + """ + + def __init__( + self, + target_sparse_ratio: float = 0.5, + threshold_trials: list[float] | None = None, + ): + """Initialize dynamic threshold calibrator. + + Args: + target_sparse_ratio: Target sparsity ratio (0.0 to 1.0) + threshold_trials: List of thresholds to try during calibration + + Note: + Calibration only supports prefill phase (seq_len > 1). + Decode phase uses the same calibrated threshold. + """ + self.target_sparse_ratio = target_sparse_ratio + + # Default threshold trials if not provided + self.threshold_trials = threshold_trials or [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 5e-2, + 1e-1, + 5e-1, + ] + + # Statistics tracking + self.sparsity_results = [] + + def calibrate(self, model: nn.Module, forward_loop: Callable) -> dict[str, Any]: + """Find optimal 'a' parameter for length-based threshold. + + Algorithm: + 1. Test all threshold trials by running forward_loop multiple times + 2. For each sample, find optimal threshold closest to target sparsity + 3. 
Use regression to find 'a' in: threshold = a / length + + Args: + model: The model with sparse attention modules + forward_loop: Callable that takes model and forwards calibration data + """ + # Extract attention modules + attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + if not attention_modules: + raise ValueError("No sparse attention modules found for calibration") + + print("Starting dynamic threshold calibration") + print(f"Target sparsity: {self.target_sparse_ratio}") + print(f"Threshold trials: {len(self.threshold_trials)}") + + # Stage 1: Collect sparsity for all sample-threshold pairs + print("\nStage 1: Collecting sparsity data...") + self.sparsity_results = [] + + # For each threshold, run forward_loop and collect per-sample statistics + for threshold_idx, threshold in enumerate( + tqdm(self.threshold_trials, desc="Testing thresholds") + ): + # Set threshold and enable calibration mode + self._set_threshold(attention_modules, threshold) + self._enable_calibration_mode(attention_modules) + + # Run forward loop and collect stats + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules) + self._disable_calibration_mode(attention_modules) + + # Store results + for sample_idx, sample_stat in enumerate(per_sample_stats): + if threshold_idx == 0: + # Initialize sample entry on first threshold + sample_length = sample_stat.get("sample_length", 0) + if sample_length > 0: + self.sparsity_results.append( + { + "sample_index": sample_idx, + "length": sample_length, + "threshold_sparsities": {}, + } + ) + + # Add sparsity for this threshold + if sample_idx < len(self.sparsity_results): + sparsity = sample_stat.get("sparsity", 0.0) + self.sparsity_results[sample_idx]["threshold_sparsities"][threshold] = sparsity + + if not self.sparsity_results: + warnings.warn("No valid sparsity measurements collected during calibration") + return {} + + print(f"Collected statistics for {len(self.sparsity_results)} samples") + + # Stage 2: Find optimal threshold for each sample and compute 'a' + print( + f"\nStage 2: Finding 'a' parameter for target sparsity {self.target_sparse_ratio:.2f}" + ) + + # Find optimal threshold for each sample + optimal_pairs = [] + for sample_result in self.sparsity_results: + # Find threshold closest to target sparsity + best_threshold, achieved_sparsity = min( + sample_result["threshold_sparsities"].items(), + key=lambda item: abs(item[1] - self.target_sparse_ratio), + ) + + optimal_pairs.append( + { + "length": sample_result["length"], + "optimal_threshold": best_threshold, + "achieved_sparsity": achieved_sparsity, + "target_sparsity": self.target_sparse_ratio, + } + ) + + if not optimal_pairs: + warnings.warn( + f"No optimal threshold pairs found for target sparsity {self.target_sparse_ratio}. " + f"Collected {len(self.sparsity_results)} samples but none achieved target sparsity." 
+ ) + return {} + + # Linear regression: threshold = a * (1/length) + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + # X = 1/length, Y = threshold + x = 1.0 / lengths + y = thresholds + + # Least squares: scale_factor = sum(x*y) / sum(x^2) + scale_factor = np.sum(x * y) / np.sum(x**2) + + # Calculate statistics + scale_factors_per_sample = y * lengths + scale_factor_std = np.std(scale_factors_per_sample) + + # Calculate R-squared for quality metric + y_pred = scale_factor * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Calculate average achieved sparsity + avg_achieved_sparsity = np.mean([p["achieved_sparsity"] for p in optimal_pairs]) + + print("\nCalibration Results:") + print(f" Threshold scale factor: {scale_factor:.6f} (std: {scale_factor_std:.6f})") + print(f" R-squared: {r_squared:.4f}") + print( + f" Average achieved sparsity: {avg_achieved_sparsity:.2%} (target: {self.target_sparse_ratio:.2%})" + ) + print(f"\nExample thresholds with λ = {scale_factor:.6f} / length:") + for length in [1024, 2048, 4096, 8192, 16384]: + print(f" Length {length:5d}: threshold = {scale_factor / length:.2e}") + + # Apply the calibrated scale factor to modules + self._apply_length_based_calibration(attention_modules, scale_factor) + + return { + "scale_factor": scale_factor, + "scale_factor_std": scale_factor_std, + "r_squared": r_squared, + "num_samples": len(optimal_pairs), + "target_sparsity": self.target_sparse_ratio, + "avg_achieved_sparsity": avg_achieved_sparsity, + "optimal_pairs": optimal_pairs, + "calibration_type": "length_based_dynamic", + } + + def _apply_length_based_calibration(self, modules: list[nn.Module], scale_factor: float): + """Apply calibrated threshold scale factor to modules. + + Args: + modules: List of attention modules + scale_factor: Calibrated scale factor for λ = scale_factor / length + """ + for module in modules: + module._sparse_method_instance.threshold_scale_factor = scale_factor + + def _enable_calibration_mode(self, modules: list[nn.Module]): + """Enable calibration mode on sparse attention modules.""" + for idx, module in enumerate(modules): + # Create stats manager if needed + if not module._stats_manager: + module._stats_manager = SparseAttentionStatsManager( + module_name=f"sparse_attn_{idx}", enabled=True + ) + else: + # Re-enable if disabled + module._stats_manager.enabled = True + + # Enable calibration mode with fresh stats + module._stats_manager.set_calibration_mode(enabled=True, reset_history=True) + module._sparse_method_instance.set_calibration_mode(True) + + def _disable_calibration_mode(self, modules: list[nn.Module]): + """Disable calibration mode (but keep stats enabled if collect_stats=True).""" + for module in modules: + if module._stats_manager: + module._stats_manager.set_calibration_mode(enabled=False) + + module._sparse_method_instance.set_calibration_mode(False) + + def _extract_calibration_stats(self, modules: list[nn.Module]) -> list[dict]: + """Extract per-sample calibration statistics from modules. 
+ + Args: + modules: List of attention modules + + Returns: + List of per-sample statistics across all modules + """ + # Collect from all stats managers + all_per_sample_stats = [] + + for module in modules: + # Skip modules without stats manager + if not hasattr(module, "_stats_manager") or module._stats_manager is None: + continue + + manager_stats = module._stats_manager.get_calibration_stats() + if manager_stats: + all_per_sample_stats.append(manager_stats) + + if not all_per_sample_stats: + return [] + + # Aggregate across modules by sample index + num_samples = len(all_per_sample_stats[0]) + aggregated_stats = [] + + for sample_idx in range(num_samples): + sparsities = [] + sample_length = 0 + + for module_stats in all_per_sample_stats: + if sample_idx < len(module_stats): + sample_stat = module_stats[sample_idx] + sparsities.append(sample_stat.get("sparsity", 0.0)) + if not sample_length and "sample_length" in sample_stat: + sample_length = sample_stat["sample_length"] + + avg_sparsity = np.mean(sparsities) if sparsities else 0.0 + + aggregated_stats.append( + { + "sparsity": avg_sparsity, + "sample_length": sample_length, + } + ) + + return aggregated_stats + + def _set_threshold(self, modules: list[nn.Module], threshold: float): + """Set threshold on sparse attention modules.""" + for module in modules: + module._sparse_method_instance.threshold = threshold diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py new file mode 100644 index 000000000..7603b4e1d --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py @@ -0,0 +1,546 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RULER dataset builder for sparse attention calibration.""" + +import random +import string +from dataclasses import dataclass +from typing import Any + +from tqdm import tqdm +from transformers import AutoTokenizer + +from . import ruler_utils + + +def _generate_target_lengths( + max_seqlen: int, num_length_bins: int = 4, min_seqlen: int = 1024 +) -> list[int]: + """Generate target lengths as descending powers of 2. 
+ + Args: + max_seqlen: Maximum sequence length + num_length_bins: Maximum number of length bins to generate + min_seqlen: Minimum sequence length threshold + + Returns: + List of target lengths in descending order + + Examples: + >>> _generate_target_lengths(32768, 4) + [32768, 16384, 8192, 4096] + >>> _generate_target_lengths(2048, 4) + [2048, 1024] + """ + target_lengths = [] + current = max_seqlen + + for _ in range(num_length_bins): + if current < min_seqlen: + break + target_lengths.append(current) + current = current // 2 + + return target_lengths + + +@dataclass +class RulerTask: + """Configuration for a RULER task.""" + + name: str + task_type: str # niah, variable_tracking, freq_words_extraction, qa + tokens_to_generate: int + template: str + answer_prefix: str + args: dict[str, Any] + + +# Task configurations based on RULER benchmark +RULER_TASKS = { + "niah_multikey_2": RulerTask( + name="niah_multikey_2", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "words", + "type_needle_v": "numbers", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "niah_multikey_3": RulerTask( + name="niah_multikey_3", + task_type="niah", + tokens_to_generate=128, + template=( + "Some special magic {type_needle_v} are hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What are all the special magic {type_needle_v} for {query} mentioned in the provided text?" + ), + answer_prefix=( + " The special magic {type_needle_v} for {query} mentioned in the provided text are" + ), + args={ + "type_haystack": "needle", + "type_needle_k": "uuids", + "type_needle_v": "uuids", + "num_needle_k": 1, + "num_needle_v": 1, + "num_needle_q": 1, + }, + ), + "vt": RulerTask( + name="vt", + task_type="variable_tracking", + tokens_to_generate=30, + template=( + "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" + "{context}\n" + "Question: Find all variables that are assigned the value {query} in the text above." + ), + answer_prefix=( + " Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assgined the value {query}, they are: " + ), + args={"num_chains": 1, "num_hops": 4}, + ), + "fwe": RulerTask( + name="fwe", + task_type="freq_words_extraction", + tokens_to_generate=50, + template=( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words. {context}\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + ), + answer_prefix=( + " Answer: According to the coded text above, " + "the three most frequently appeared words are:" + ), + args={"alpha": 2.0}, + ), + "qa_1": RulerTask( + name="qa_1", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. 
" + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "squad"}, + ), + "qa_2": RulerTask( + name="qa_2", + task_type="qa", + tokens_to_generate=32, + template=( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n{context}\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {query}" + ), + answer_prefix=" Answer:", + args={"dataset": "hotpotqa"}, + ), +} + + +class RulerDatasetBuilder: + """Builder for RULER calibration datasets.""" + + def __init__( + self, + samples: int, + max_seqlen: int, + tokenizer_name_or_path: str | object, + num_length_bins: int = 4, + max_length_filter: int = 65536, + seed: int = 42, + ): + """Initialize RULER dataset builder. + + Args: + samples: Total number of samples to generate (distributed evenly across length bins) + max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2) + tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object + seed: Random seed for reproducibility + num_length_bins: Number of length bins to generate (default: 4) + max_length_filter: Maximum sequence length to keep (default: 65536) + + Note: + Length bins are auto-generated as descending powers of 2: + [max_seqlen, max_seqlen/2, max_seqlen/4, ...] + Generation stops when num_length_bins is reached or length < 1024. + Subtasks are set to all the difficult tasks defined in RULER_TASKS. + """ + # Validate inputs + if samples <= 0: + raise ValueError(f"samples must be positive, got {samples}") + if max_seqlen < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {max_seqlen}") + + # Store parameters + self.total_samples = samples + self.max_seqlen = max_seqlen + self.num_length_bins = num_length_bins + self.subtasks = list(RULER_TASKS.keys()) + self.tokenizer_name_or_path = tokenizer_name_or_path + self.seed = seed + self.max_length_filter = max_length_filter + + # Generate target lengths and validate + self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024) + if not self.target_lengths: + raise ValueError(f"No valid target lengths generated from max_seqlen={max_seqlen}") + + # Distribute samples evenly across lengths + self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths) + + # Initialize tokenizer + if isinstance(tokenizer_name_or_path, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + else: + self.tokenizer = tokenizer_name_or_path + random.seed(seed) + + def build_calibration_dataset(self) -> list[dict[str, Any]]: + """Build the complete calibration dataset. 
+ + Returns: + List of calibration samples with 'input' and 'length' fields + """ + all_samples = [] + + # Generate calibration samples + for num_samples, target_length in tqdm( + zip(self.samples_per_length, self.target_lengths), + desc="Generating RULER calibration samples", + total=len(self.target_lengths), + ): + samples_per_task = max(num_samples // len(self.subtasks), 1) + + # Generate equal samples for each task + for task_name in self.subtasks: + for sample_idx in range(samples_per_task): + sample = self._generate_sample(task_name, target_length, sample_idx) + if sample and sample["length"] <= self.max_length_filter: + all_samples.append(sample) + + random.shuffle(all_samples) + return all_samples + + def _generate_sample( + self, task_name: str, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a single RULER sample. + + Args: + task_name: Name of the RULER task + target_length: Target sequence length in tokens + sample_idx: Index of the sample (for uniqueness) + + Returns: + Dict with 'input', 'length', and metadata fields + """ + task = RULER_TASKS[task_name] + + if task.task_type == "niah": + return self._generate_niah_sample(task, target_length, sample_idx) + elif task.task_type == "variable_tracking": + return self._generate_vt_sample(task, target_length, sample_idx) + elif task.task_type == "freq_words_extraction": + return self._generate_fwe_sample(task, target_length, sample_idx) + elif task.task_type == "qa": + return self._generate_qa_sample(task, target_length, sample_idx) + else: + raise ValueError(f"Unknown task type: {task.task_type}") + + def _generate_niah_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a needle-in-haystack sample.""" + args = task.args + + # Find optimal haystack size for target length + optimal_haystack = ruler_utils.find_optimal_haystack_size( + tokenizer=self.tokenizer, + max_seq_length=target_length, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + ) + + # Generate sample using official RULER implementation + sample = ruler_utils.generate_niah_sample( + num_haystack=optimal_haystack, + tokenizer=self.tokenizer, + template=task.template, + answer_prefix=task.answer_prefix, + tokens_to_generate=task.tokens_to_generate, + type_haystack=args.get("type_haystack", "essay"), + type_needle_k=args.get("type_needle_k", "words"), + type_needle_v=args.get("type_needle_v", "numbers"), + num_needle_k=args.get("num_needle_k", 1), + num_needle_v=args.get("num_needle_v", 1), + num_needle_q=args.get("num_needle_q", 1), + random_seed=self.seed + sample_idx, + ) + + # Add task metadata + sample["task"] = task.name + sample["target_length"] = target_length + sample["sample_idx"] = sample_idx + + return sample + + def _generate_vt_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a variable tracking sample.""" + args = task.args + num_chains = args["num_chains"] + num_hops = args["num_hops"] + + # Generate variable chains + variables = [] + chains = [] + for _ in range(num_chains): + chain = [self._generate_random_variable() for _ in range(num_hops + 1)] + variables.extend(chain) + 
chains.append(chain) + + # Generate assignments + assignments = [ + f"VAR {chain[i]} = {chain[i + 1]}" for chain in chains for i in range(len(chain) - 1) + ] + + # Create context with padding + context = self._pad_context_with_text( + "\n".join(assignments), target_length, "variable tracking context" + ) + + # Select a query value + query_value = random.choice([chain[-1] for chain in chains]) + + # Format template + template = task.template.format(context=context, query=query_value) + + # Count variables with the query value + num_v = sum(1 for chain in chains if chain[-1] == query_value) + + # Add answer prefix + full_input = template + task.answer_prefix.format(num_v=num_v, query=query_value) + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_fwe_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a frequency word extraction sample.""" + # Generate coded words with frequencies + num_unique_words = 50 + coded_words = [self._generate_coded_word() for _ in range(num_unique_words)] + + # Assign frequencies (make top 3 clearly more frequent) + frequencies = {} + for i, word in enumerate(coded_words): + if i < 3: + frequencies[word] = random.randint(20, 30) # High frequency + else: + frequencies[word] = random.randint(1, 10) # Low frequency + + # Generate the coded text + word_list = [] + for word, freq in frequencies.items(): + word_list.extend([word] * freq) + random.shuffle(word_list) + + # Add dots for separation + coded_text = " .... ".join(word_list) + + # Pad to target length + context = self._pad_context_with_text(coded_text, target_length, "coded text padding") + + # Format template + template = task.template.format(context=context) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _generate_qa_sample( + self, task: RulerTask, target_length: int, sample_idx: int + ) -> dict[str, Any]: + """Generate a QA sample.""" + # Generate synthetic documents + num_docs = 5 + documents = [] + + # Create a simple QA pair + answer = self._generate_random_phrase() + question = f"What is the special code mentioned in document {random.randint(1, num_docs)}?" + + for i in range(num_docs): + doc_text = self._generate_document_text(200) # Base document + if i == 2: # Insert answer in one document + doc_text += f" The special code is {answer}. 
" + documents.append(f"Document {i + 1}:\n{doc_text}\n") + + # Combine documents + context_base = "\n".join(documents) + + # Pad to target length + context = self._pad_context_with_text( + context_base, target_length, "additional document text" + ) + + # Format template + template = task.template.format(context=context, query=question) + full_input = template + task.answer_prefix + + # Tokenize to get actual length + tokens = self.tokenizer.encode(full_input, add_special_tokens=False) + + return { + "input": full_input, + "length": len(tokens), + "task": task.name, + "target_length": target_length, + "sample_idx": sample_idx, + } + + def _pad_context_with_text( + self, base_context: str, target_length: int, padding_type: str + ) -> str: + """Pad context to approach target length.""" + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + while len(tokens) < target_length * 0.7: # Leave room for template + if padding_type == "variable tracking context": + padding = ( + f" VAR {self._generate_random_variable()} = {self._generate_random_variable()}." + ) + elif padding_type == "coded text padding": + padding = f" .... {self._generate_coded_word()} .... " + else: + padding = " " + self._generate_essay_text(50) + + base_context += padding + tokens = self.tokenizer.encode(base_context, add_special_tokens=False) + + if len(tokens) > target_length * 0.9: + # Truncate if too long + base_context = self.tokenizer.decode(tokens[: int(target_length * 0.8)]) + + return base_context + + def _generate_random_word(self) -> str: + """Generate a random word.""" + return "".join(random.choices(string.ascii_lowercase, k=random.randint(5, 10))) + + def _generate_random_variable(self) -> str: + """Generate a random variable name.""" + return "".join(random.choices(string.ascii_uppercase, k=1)) + "".join( + random.choices(string.digits, k=3) + ) + + def _generate_coded_word(self) -> str: + """Generate a coded word.""" + return "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + + def _generate_random_phrase(self) -> str: + """Generate a random phrase.""" + words = [self._generate_random_word() for _ in range(random.randint(2, 4))] + return " ".join(words) + + def _generate_essay_text(self, num_words: int) -> str: + """Generate essay-like text.""" + topics = [ + "technology", + "science", + "nature", + "history", + "culture", + "education", + "health", + "economics", + "politics", + "philosophy", + "art", + "literature", + ] + + sentences = [] + words_generated = 0 + + while words_generated < num_words: + topic = random.choice(topics) + word1 = self._generate_random_word() + word2 = self._generate_random_word() + word3 = self._generate_random_word() + sentence = f"The {topic} of {word1} is {word2} and {word3}. " + sentences.append(sentence) + words_generated += len(sentence.split()) + + return " ".join(sentences) + + def _generate_document_text(self, num_words: int) -> str: + """Generate document-like text.""" + return self._generate_essay_text(num_words) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh new file mode 100755 index 000000000..8f492b477 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh @@ -0,0 +1,72 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Download RULER data files for attention sparsity calibration. +# Downloads Paul Graham Essays URL list and essay content from official sources. + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DATA_DIR="${SCRIPT_DIR}/data" +ESSAYS_DIR="${DATA_DIR}/essays" +RULER_URLS_FILE="${DATA_DIR}/PaulGrahamEssays_URLs.txt" +RULER_URLS_URL="https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt" + +echo "Downloading RULER data files for attention sparsity calibration..." + +# Create directories +mkdir -p "${DATA_DIR}" +mkdir -p "${ESSAYS_DIR}" + +# Step 1: Download URL list +if [ -f "${RULER_URLS_FILE}" ]; then + echo "URL list already exists: ${RULER_URLS_FILE}" +else + echo "Downloading URL list..." + curl -fsSL "${RULER_URLS_URL}" -o "${RULER_URLS_FILE}" + echo "Downloaded: ${RULER_URLS_FILE}" +fi + +# Step 2: Download essay files (only GitHub .txt files) +echo "Downloading essay files..." +DOWNLOAD_COUNT=0 +SKIP_COUNT=0 + +while IFS= read -r url; do + # Only process GitHub .txt URLs + if [[ "${url}" == https://github.com*.txt ]]; then + # Extract filename from URL + filename=$(basename "${url}") + filepath="${ESSAYS_DIR}/${filename}" + + if [ -f "${filepath}" ]; then + ((SKIP_COUNT++)) + else + # Convert GitHub URL to raw URL + raw_url="${url/github.com/raw.githubusercontent.com}" + raw_url="${raw_url/\/raw\//\/}" + + if curl -fsSL "${raw_url}" -o "${filepath}" 2>/dev/null; then + ((DOWNLOAD_COUNT++)) + else + echo "Warning: Failed to download ${filename}" + fi + fi + fi +done < "${RULER_URLS_FILE}" + +echo "Downloaded ${DOWNLOAD_COUNT} new essays, ${SKIP_COUNT} already existed." +echo "Done! RULER data files are ready in ${DATA_DIR}" diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py new file mode 100644 index 000000000..70d4da81b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/ruler_utils.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copied and Adapted from https://github.com/NVIDIA/RULER +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Official RULER dataset generation utilities adapted for Model Optimizer. + +This module contains core logic from the RULER benchmark (https://github.com/NVIDIA/RULER) +adapted to work as a library for calibration purposes. The generation logic closely follows +the official RULER implementation to ensure dataset consistency. + +Key adaptations from official RULER: +- Converted from CLI scripts to library functions +- Works with HuggingFace tokenizers directly +- Removed file I/O, returns data structures +- Simplified for calibration use case (primarily NIAH tasks) +""" + +import logging +import random +import re +import uuid +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +# Needle/Haystack template from official RULER +NEEDLE_TEMPLATE = "One of the special magic {type_needle_v} for {key} is: {value}." + +# Depth positions for needle insertion (from official RULER) +DEPTHS = [ + 0, + 2, + 5, + 7, + 10, + 12, + 15, + 18, + 20, + 23, + 25, + 28, + 30, + 33, + 35, + 38, + 40, + 43, + 45, + 48, + 50, + 53, + 55, + 58, + 60, + 62, + 65, + 67, + 70, + 72, + 75, + 77, + 80, + 82, + 85, + 87, + 90, + 92, + 95, + 97, + 100, +] + +# Data directory for RULER calibration files (downloaded via download_ruler_data.sh) +DATA_DIR = Path(__file__).parent / "data" +RULER_URLS_FILE = DATA_DIR / "PaulGrahamEssays_URLs.txt" +ESSAYS_DIR = DATA_DIR / "essays" + + +def _get_data_dir() -> Path: + """Get data directory for RULER data. + + Returns: + Path to data directory under calibration/ (created if doesn't exist) + """ + data_dir = Path(__file__).parent / "data" + data_dir.mkdir(parents=True, exist_ok=True) + return data_dir + + +def _load_paul_graham_essays_from_files() -> str: + """Load Paul Graham essays from local files. + + Reads essay .txt files from the data/essays directory. + Files must be downloaded first using download_ruler_data.sh. + + Returns: + Combined essay text + + Raises: + RuntimeError: If essays directory doesn't exist or is empty + """ + if not ESSAYS_DIR.exists(): + raise RuntimeError( + f"Essays directory not found at {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + essay_files = list(ESSAYS_DIR.glob("*.txt")) + if not essay_files: + raise RuntimeError( + f"No essay files found in {ESSAYS_DIR}.\n" + "Please run the download script first:\n" + " bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh" + ) + + logger.info(f"Loading {len(essay_files)} Paul Graham essays from local files...") + + all_essays = [] + for filepath in essay_files: + text = filepath.read_text() + all_essays.append(text) + + combined_text = " ".join(all_essays) + logger.info(f"Loaded {len(all_essays)} essays successfully") + + return combined_text + + +def _load_paul_graham_essays() -> str: + """Load Paul Graham essays from local files. 
+ + Essay files must be downloaded first using download_ruler_data.sh. + + Returns: + Essay text as string + """ + essay_text = _load_paul_graham_essays_from_files() + return re.sub(r"\s+", " ", essay_text) + + +def _load_word_lists(): + """Load word lists for random word generation. + + Returns: + List of words (adj-noun combinations) + """ + import wonderwords + + # Load wonderwords lists (same as official RULER) + nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") + adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") + words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] + words = sorted(set(words)) + return words + + +# Global word list (loaded once) +_WORD_LIST = None + + +def generate_random_number(num_digits=7) -> str: + """Generate random number (from official RULER).""" + lower_bound = 10 ** (num_digits - 1) + upper_bound = 10**num_digits - 1 + return str(random.randint(lower_bound, upper_bound)) + + +def generate_random_word() -> str: + """Generate random word (from official RULER).""" + global _WORD_LIST + if _WORD_LIST is None: + _WORD_LIST = _load_word_lists() + return random.choice(_WORD_LIST) + + +def generate_random_uuid() -> str: + """Generate random UUID (from official RULER).""" + return str(uuid.UUID(int=random.getrandbits(128), version=4)) + + +def generate_random(type_needle: str) -> str: + """Generate random needle value based on type (from official RULER). + + Args: + type_needle: Type of needle ('numbers', 'words', 'uuids') + + Returns: + Random value as string + """ + if type_needle == "numbers": + return generate_random_number() + elif type_needle == "words": + return generate_random_word() + elif type_needle == "uuids": + return generate_random_uuid() + else: + raise ValueError(f"Unknown needle type: {type_needle}") + + +def generate_niah_sample( + num_haystack: int, + tokenizer, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + type_needle_k: str = "words", + type_needle_v: str = "numbers", + num_needle_k: int = 1, + num_needle_v: int = 1, + num_needle_q: int = 1, + random_seed: int = 42, +) -> dict[str, Any]: + """Generate a single NIAH (Needle in a Haystack) sample. + + This function implements the core generation logic from official RULER's niah.py, + adapted to work as a library function. 
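+
+    Each needle is rendered from NEEDLE_TEMPLATE, for example (illustrative key/value):
+    "One of the special magic numbers for brave-lemon is: 4203917."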
+ + Args: + num_haystack: Number of haystack items/words + tokenizer: HuggingFace tokenizer (AutoTokenizer instance) + template: NIAH question template + answer_prefix: Answer prefix template + tokens_to_generate: Expected number of generation tokens + type_haystack: Type of haystack ('essay', 'noise', 'needle') + type_needle_k: Type of needle keys ('numbers', 'words', 'uuids') + type_needle_v: Type of needle values ('numbers', 'words', 'uuids') + num_needle_k: Number of needle keys + num_needle_v: Number of needle values per key + num_needle_q: Number of needles to query + random_seed: Random seed for this sample + + Returns: + Dictionary with 'input', 'outputs', 'length' keys + """ + import nltk + from nltk.tokenize import sent_tokenize + + try: + nltk.data.find("tokenizers/punkt") + except LookupError: + nltk.download("punkt", quiet=True) + nltk.download("punkt_tab", quiet=True) + + if random_seed is not None: + random.seed(random_seed) + + # Ensure num_needle_k >= num_needle_q + num_needle_k = max(num_needle_k, num_needle_q) + + # Generate needles (keys and values) + keys, values, needles = [], [], [] + for _ in range(num_needle_k): + keys.append(generate_random(type_needle_k)) + value = [] + for _ in range(num_needle_v): + value.append(generate_random(type_needle_v)) + needles.append( + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=keys[-1], + value=value[-1], + ) + ) + values.append(value) + + random.shuffle(needles) + + # Generate context based on haystack type + if type_haystack == "essay": + # Load essay corpus + essay_text = _load_paul_graham_essays() + haystack = essay_text.split(" ") + + # Create text from haystack + if num_haystack <= len(haystack): + text = " ".join(haystack[:num_haystack]) + else: + # Repeat haystack as needed + repeats = (num_haystack + len(haystack) - 1) // len(haystack) + text = " ".join((haystack * repeats)[:num_haystack]) + + # Insert needles at various depths + document_sents = sent_tokenize(text.strip()) + insertion_positions = [ + 0, + *sorted( + int(len(document_sents) * (depth / 100)) + for depth in random.sample(DEPTHS, len(needles)) + ), + len(document_sents), + ] + + document_sents_list = [] + for i in range(1, len(insertion_positions)): + last_pos = insertion_positions[i - 1] + next_pos = insertion_positions[i] + document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) + if i - 1 < len(needles): + document_sents_list.append(needles[i - 1]) + + context = " ".join(document_sents_list) + + if type_haystack == "noise": + haystack_sent = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
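+        # Tile the filler sentence num_haystack times; the needles are spliced in below at
+        # randomly sampled positions (descending order keeps the insert indexes valid).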
+ sentences = [haystack_sent] * num_haystack + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + elif type_haystack == "needle": + sentences = [ + NEEDLE_TEMPLATE.format( + type_needle_v=type_needle_v, + key=generate_random(type_needle_k), + value=generate_random(type_needle_v), + ) + for _ in range(num_haystack) + ] + + indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) + for index, element in zip(indexes, needles): + sentences.insert(index, element) + context = "\n".join(sentences) + + # Generate query and answer + indices = random.sample(range(num_needle_k), num_needle_q) + queries = [keys[i] for i in indices] + answers = [a for i in indices for a in values[i]] + query = ", ".join(queries[:-1]) + ", and " + queries[-1] if len(queries) > 1 else queries[0] + + # Format template (adjust for singular vs plural) + type_needle_v_display = type_needle_v + formatted_template = template + if num_needle_q * num_needle_v == 1: + formatted_template = formatted_template.replace("Some", "A") + formatted_template = formatted_template.replace("are all", "is") + formatted_template = formatted_template.replace("are", "is") + formatted_template = formatted_template.replace("answers", "answer") + type_needle_v_display = type_needle_v[:-1] # remove "s" + + input_text = formatted_template.format( + type_needle_v=type_needle_v_display, + context=context, + query=query, + ) + + # Add answer prefix + formatted_answer_prefix = answer_prefix.format( + type_needle_v=type_needle_v_display, + query=query, + ) + input_text = input_text + formatted_answer_prefix + + # Calculate actual length + if hasattr(tokenizer, "encode"): + # HuggingFace tokenizer + tokens = tokenizer.encode(input_text, add_special_tokens=False) + length = len(tokens) + tokens_to_generate + else: + # Fallback + length = len(input_text.split()) + tokens_to_generate + + return { + "input": input_text, + "outputs": answers, + "length": length, + } + + +def find_optimal_haystack_size( + tokenizer, + max_seq_length: int, + template: str, + answer_prefix: str, + tokens_to_generate: int = 128, + type_haystack: str = "essay", + **kwargs, +) -> int: + """Find optimal haystack size using binary search (from official RULER). 
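+
+    The routine first renders one probe sample to estimate tokens per haystack item, then
+    bisects on the item count until the rendered sample fits within max_seq_length
+    (illustratively, at roughly 1.3 tokens per essay word, max_seq_length=8192 lands near
+    6000 words).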
+ + Args: + tokenizer: HuggingFace tokenizer + max_seq_length: Maximum sequence length + tokens_to_generate: Expected generation tokens + type_haystack: Type of haystack + template: NIAH question template + answer_prefix: Answer prefix template + **kwargs: Additional arguments for generate_niah_sample + + Returns: + Optimal number of haystack items + """ + # Determine incremental step based on haystack type + if type_haystack == "essay": + incremental = 500 + elif type_haystack in ["noise", "needle"]: + incremental = 25 + else: + incremental = 100 + + if max_seq_length < 4096 and type_haystack != "essay": + incremental = 5 + + # Estimate tokens per haystack item + sample = generate_niah_sample( + incremental, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + + if hasattr(tokenizer, "encode"): + sample_tokens = len(tokenizer.encode(sample["input"], add_special_tokens=False)) + else: + sample_tokens = len(sample["input"].split()) + + tokens_per_haystack = sample_tokens / incremental + estimated_max = int((max_seq_length / tokens_per_haystack) * 3) + + # Binary search for optimal size + lower_bound = incremental + upper_bound = max(estimated_max, incremental * 2) + optimal_num_haystack = None + + logger.info(f"Estimated {tokens_per_haystack:.1f} tokens per haystack") + logger.info(f"Binary search bounds: {lower_bound} to {upper_bound}") + + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + sample = generate_niah_sample( + mid, + tokenizer, + template, + answer_prefix, + tokens_to_generate, + type_haystack=type_haystack, + **kwargs, + ) + total_tokens = sample["length"] + + logger.debug(f"Testing haystack size: {mid}, tokens: {total_tokens}/{max_seq_length}") + + if total_tokens <= max_seq_length: + optimal_num_haystack = mid + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + + final_size = optimal_num_haystack if optimal_num_haystack is not None else incremental + logger.info(f"Optimal haystack size: {final_size}") + + return final_size diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py new file mode 100644 index 000000000..7fc985f0a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration classes for sparse attention optimization.""" + +from collections.abc import Callable +from typing import Any + +from pydantic import Field, field_validator + +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField + +# Type definitions for sparse configuration +SparseAttributeConfig = dict[str, Any] # Configuration for a specific pattern + +SparseAttentionCfgType = dict[ + str | Callable, # Pattern or callable for matching modules + SparseAttributeConfig, # Configuration dict with threshold, enable, etc. +] + + +class SparseAttentionAttributeConfig(ModeloptBaseConfig): + """Sparse attention attribute configuration for pattern-based module config.""" + + method: str = ModeloptField( + default="flash_skip_softmax", + title="Sparse attention method.", + description="The sparse attention method to use (e.g., 'flash_skip_softmax').", + ) + + enable: bool = ModeloptField( + default=True, + title="Enable sparse attention.", + description="If True, enables sparse attention. If False, bypasses sparsity.", + ) + + threshold: float | dict[str, float] = ModeloptField( + default=1e-3, + title="Sparsity threshold.", + description=( + "Threshold for determining which attention values to skip. " + "Can be a float or dict with phase-specific values." + ), + ) + + br: int = ModeloptField( + default=128, + title="Block row size.", + description="Block row size for block-wise sparsity in Flash Attention.", + ) + + bc: int = ModeloptField( + default=128, + title="Block column size.", + description="Block column size for block-wise sparsity in Flash Attention.", + ) + + backend: str = ModeloptField( + default="pytorch", + title="Backend implementation.", + description=( + "Backend to use for sparse attention computation. " + "Only 'pytorch' is supported, which uses softmax patching with F.softmax. " + "Requires model to be loaded with attn_implementation='eager'." + ), + ) + + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass for monitoring.", + ) + + is_causal: bool = ModeloptField( + default=True, + title="Causal attention flag.", + description=( + "Whether the model uses causal (autoregressive) attention. " + "If True, sparsity statistics are calculated over the lower triangle only. " + "Defaults to True for decoder-only models like GPT, LLaMA, etc." + ), + ) + + @field_validator("method") + @classmethod + def validate_method(cls, v): + """Validate method is a string.""" + if not isinstance(v, str): + raise ValueError("method must be a string") + return v + + @field_validator("backend") + @classmethod + def validate_backend(cls, v): + """Validate backend is pytorch.""" + if v != "pytorch": + raise ValueError( + f"Invalid backend: {v}. Only 'pytorch' backend is supported. " + f"Model must be loaded with attn_implementation='eager'." + ) + return v + + @field_validator("br", "bc") + @classmethod + def validate_block_size(cls, v): + """Validate block sizes are positive integers.""" + if v <= 0: + raise ValueError(f"Block size must be positive, got {v}") + return v + + @field_validator("threshold") + @classmethod + def validate_threshold(cls, v): + """Validate threshold is in valid range (0, 1) or dict with valid phases.""" + if isinstance(v, dict): + # Validate phase keys + valid_phases = {"prefill", "decode", "default"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid threshold phases: {invalid_keys}. 
Valid phases: {valid_phases}" + ) + # Validate all values are in range (0, 1) + for phase, threshold in v.items(): + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + raise ValueError( + f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + ) + elif isinstance(v, (int, float)): + if v <= 0 or v >= 1: + raise ValueError(f"Threshold must be in range (0, 1), got {v}") + else: + raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}") + return v + + +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. " + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + +class SparseAttentionConfig(ModeloptBaseConfig): + """Base configuration for sparse attention optimization. 
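+
+    A hedged usage sketch (assuming ``sparsify`` is exported at the package level, as in the
+    examples in conversion.py)::
+
+        import modelopt.torch.sparsity.attention_sparsity as sparse_attn
+        from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
+
+        model = sparse_attn.sparsify(model, SKIP_SOFTMAX_DEFAULT)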
+ + This base configuration provides the common structure for all sparse + attention methods and supports pattern-based layer configuration. + """ + + # Pattern-based sparse configuration (similar to quant_cfg in quantization) + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attention*": {"method": "flash_skip_softmax", "enable": True}, + "default": {"enable": False}, + }, + title="Sparse attention configuration", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names " + "(or 'calibration' for global calibration settings), values are configuration dicts with parameters like " + "'threshold', 'enable', etc.", + validate_default=True, + ) + + # Export configuration + export_format: str | None = Field( + None, description="Export format for sparse attention (e.g., 'onnx', 'tensorrt')" + ) + + +class FlashSkipSoftmaxConfig(SparseAttentionConfig): + """Configuration for Flash Attention-aware softmax skip sparse attention.""" + + # Override sparse_cfg with flash_skip_softmax specific defaults + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, # Enable statistics collection + "enable": True, + }, + "default": {"enable": False}, + }, + title="Flash softmax skip sparse configuration", + description="Pattern-based configuration with flash_skip_softmax specific defaults. " + "Includes FA block sizes (br, bc) and correction factor settings.", + validate_default=True, + ) + + +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 128, + "max_seqlen": 8192, + }, + "*attn*": { + "method": "flash_skip_softmax", + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "collect_stats": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +__all__ = [ + "SKIP_SOFTMAX_CALIB", + "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", + "FlashSkipSoftmaxConfig", + "SparseAttentionAttributeConfig", + "SparseAttentionCfgType", + "SparseAttentionConfig", + "SparseAttributeConfig", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py new file mode 100644 index 000000000..8b849e247 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conversion and restoration utilities for sparse attention.""" + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.utils import get_unwrapped_name + +from .config import SparseAttentionConfig +from .plugins.huggingface import register_sparse_attention_on_the_fly +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry + + +def is_attn_sparsified(model: nn.Module) -> bool: + """Check if a model has sparse attention applied. + + Similar to quantization's is_quantized for API consistency. + + Args: + model: Model to check + + Returns: + True if model contains any SparseAttentionModule instances + """ + return any(isinstance(module, SparseAttentionModule) for module in model.modules()) + + +def convert_to_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig +) -> ConvertReturnType: + """Convert model to use sparse attention. + + Args: + model: Model to convert + config: Sparse attention configuration + + Returns: + Tuple of (converted_model, metadata) + """ + # Initialize the true module if necessary + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + # Register sparse attention modules dynamically + register_sparse_attention_on_the_fly(model) + + # Replace attention modules with sparse versions + replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version) + + # Apply configuration to sparse attention modules + sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} + set_sparse_attention_by_cfg(model, sparse_cfg) + + # Create metadata + metadata = {} + update_sparse_attention_metadata(model, config, metadata) + + return model, metadata + + +def replace_sparse_attention_modules(model: nn.Module, version=None): + """Replace regular attention modules with sparse attention modules. + + Recursively replace all attention modules in the model with their sparse attention counterparts. 
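+
+    Only modules whose exact type is registered in ``SparseAttentionRegistry`` (populated by
+    ``register_sparse_attention_on_the_fly``) are replaced; all other modules are left untouched.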
+ + Args: + model: Model to process + version: State version for tracking (optional) + """ + # Recursively replace modules + _replace_sparse_attention_modules(model, version=version) + + # Count and report replaced modules + replaced_count = sum(isinstance(m, SparseAttentionModule) for _, m in model.named_modules()) + if replaced_count > 0: + print(f"Inserted {replaced_count} sparse attention modules") + + +def _replace_sparse_attention_modules(model: nn.Module, version=None): + """Helper function for replace_sparse_attention_modules.""" + for name, child in model.named_children(): + if type(child) in SparseAttentionRegistry: + # REPLACE on the parent (model), not on child + sparse_module = SparseAttentionRegistry.convert(child) + setattr(model, name, sparse_module) + + # Now recurse into whichever module is now at `model.name` + _replace_sparse_attention_modules(getattr(model, name), version=version) + + +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): + """Apply sparse attention configuration to model. + + Similar to quantization's set_quantizer_by_cfg. + + Args: + model: Model with sparse attention modules + sparse_cfg: Sparse configuration dictionary mapping patterns to attributes + """ + sparse_cfg = sparse_cfg.copy() + + # Apply default first if exists + if "default" in sparse_cfg: + set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) + sparse_cfg.pop("default") + + # Apply pattern-specific configs + for pattern, cfg in sparse_cfg.items(): + set_sparse_attention_attribute(model, pattern, cfg) + + +def set_sparse_attention_attribute( + model: nn.Module, + wildcard_or_filter: str | Callable, + attribute_cfg: dict[str, Any], +): + """Set sparse attention attributes for modules matching pattern. + + Similar to quantization's set_quantizer_attribute. + + Args: + model: Model to configure + wildcard_or_filter: Pattern to match module names + attribute_cfg: Attributes to apply (must include 'method') + """ + # Filter out model-level configs that shouldn't be passed to modules + module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} + + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + # Check pattern match + matched = False + if isinstance(wildcard_or_filter, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter) + elif callable(wildcard_or_filter): + matched = wildcard_or_filter(name) + else: + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") + + if matched: + # Apply config using the same method as TensorQuantizer + module.set_from_attribute_config(module_cfg) + + +def restore_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig, metadata: MetadataDict +) -> nn.Module: + """Restore sparse attention model from saved state. + + Args: + model: Model to restore + config: Sparse attention configuration + metadata: Saved metadata + + Returns: + Restored model + """ + # Convert to sparse attention model + model, _ = convert_to_sparse_attention_model(model, config) + + # Restore sparse attention state from metadata + if "sparse_attention_state" in metadata: + restore_sparse_attention_state(model, metadata["sparse_attention_state"]) + + return model + + +def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]): + """Restore sparse attention state from state dict. 
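+
+    Expected layout (the module name and values are illustrative; see
+    ``update_sparse_attention_metadata`` for the writer side)::
+
+        {"model.layers.0.self_attn": {"method": "flash_skip_softmax",
+                                      "method_config": {"threshold": 1e-3, "br": 128}}}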
+ + Args: + model: Model with sparse attention modules + state_dict: Saved state dictionary + """ + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + if module_name in state_dict: + module_state = state_dict[module_name] + + # Restore method and config + if "method" in module_state: + module._method = module_state["method"] + if "method_config" in module_state: + # Restore config attributes + for key, val in module_state["method_config"].items(): + setattr(module, f"_{key}", val) + + # Re-setup with restored config + module._setup() + + +def update_sparse_attention_metadata( + model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict +) -> None: + """Update metadata with sparse attention state. + + Args: + model: Model with sparse attention + config: Configuration used + metadata: Metadata dict to update + """ + sparse_state = {} + + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + + # Save the method configuration that was used + # _method_config already contains the validated config dict + module_state = { + "method": module._sparse_method_instance.name, + "method_config": module._method_config.copy(), + } + + sparse_state[module_name] = module_state + + metadata["sparse_attention_state"] = sparse_state + metadata["sparse_attention_config"] = ( + config.model_dump() if hasattr(config, "model_dump") else vars(config) + ) + + +def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Disable sparse attention for matching modules. + + Similar to mtq.disable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*lm_head*", "*layer_0*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Disable sparse attention for lm_head + >>> sparse_attn.disable_sparse_attention(model, "*lm_head*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.disable() + + +def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Enable sparse attention for matching modules. + + Similar to mtq.enable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*attention*", "*attn*", etc. 
+ + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Re-enable sparse attention for all attention modules + >>> sparse_attn.enable_sparse_attention(model, "*attention*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.enable() + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Similar to mtq.print_quant_summary for API consistency. + + Args: + model: Model with sparse attention applied + + Prints: + - Total sparse attention modules + - Enabled vs disabled count + - Method distribution + - Configuration summary by module + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> sparse_attn.print_sparse_attention_summary(model) + """ + sparse_modules = [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + sparse_modules.append((name, module)) + + if not sparse_modules: + print("No sparse attention modules found in model") + return + + enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) + disabled_count = len(sparse_modules) - enabled_count + + # Count methods + method_counts = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + method_counts[method] = method_counts.get(method, 0) + 1 + + print(f"Total sparse attention modules: {len(sparse_modules)}") + print(f"Enabled: {enabled_count}") + print(f"Disabled: {disabled_count}") + + if method_counts: + print("\nMethods:") + for method, count in sorted(method_counts.items()): + print(f"{method}: {count}") + + for name, module in sparse_modules: + method = getattr(module, "_method", "unknown") + threshold_info = module.get_threshold_info() + + # Format threshold information based on type + threshold_type = threshold_info.get("type", "unknown") + + if threshold_type == "dynamic": + scale_factor = threshold_info.get("scale_factor") + threshold_str = f"Dynamic (λ={scale_factor:.6f})" + elif threshold_type == "static": + value = threshold_info.get("value") + threshold_str = ( + f"Static ({value:.2e})" if isinstance(value, (int, float)) else f"Static ({value})" + ) + elif threshold_type == "static_phased": + thresholds = threshold_info.get("thresholds", {}) + threshold_str = f"Phased {thresholds}" + else: + threshold_str = "N/A" + + print(f"Method: {method}, Threshold: {threshold_str}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py new file mode 100644 index 000000000..8a109fda7 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention methods package.""" + +from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method + +__all__ = [ + "SparseAttentionMethod", + "get_sparse_method", + "register_sparse_method", +] + +# Import method implementations to trigger registration +from . import flash_skip_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py new file mode 100644 index 000000000..c12d77548 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Flash Attention-aware softmax skip method for sparse attention. + +This module implements block-wise sparsity that aligns with Flash Attention's +processing pattern for optimal performance. +""" + +import math +from typing import Any + +import numpy as np +import torch + +from . import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("flash_skip_softmax") +class FlashSkipSoftmax(SparseAttentionMethod): + """Flash Attention-aware softmax skip sparse attention method. + + Implements row-level block-wise sparsity aligned with Flash Attention's + processing pattern for optimal performance and accuracy. + """ + + def __init__(self, method_config: dict | None = None): + """Initialize Flash softmax skip method. + + Args: + method_config: Configuration dict with threshold, br, bc, is_causal, etc. + All required fields should have defaults from SparseAttentionAttributeConfig. 
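+
+            A hedged illustration (values mirror the FlashSkipSoftmaxConfig defaults)::
+
+                {"threshold": {"prefill": 1e-3, "decode": 1e-5}, "br": 128, "bc": 128,
+                 "backend": "pytorch", "is_causal": True, "collect_stats": True}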
+ """ + config = method_config or {} + + # Extract configuration + self.threshold_config = config["threshold"] + self.br = config["br"] + self.bc = config["bc"] + self.backend = config["backend"] + self.is_causal = config["is_causal"] + + # Optional parameters not in Pydantic config + self.enable_correction_factor = config.get("enable_correction_factor", True) + self.phase = config.get("phase", None) + + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + + # Initialize threshold + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + "default", self.threshold_config.get("prefill", 1e-4) + ) + else: + self.threshold = self.threshold_config + + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + + def _update_threshold(self, phase: str): + """Update threshold based on phase.""" + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + phase, self.threshold_config.get("default", self.threshold) + ) + + def _infer_phase(self, attention_scores: torch.Tensor) -> str: + """Infer phase from attention scores shape.""" + return "decode" if attention_scores.shape[2] == 1 else "prefill" + + def _reshape_to_blocks( + self, tensor: torch.Tensor, br: int, bc: int + ) -> tuple[torch.Tensor, ...]: + """Reshape tensor into blocks for Flash Attention processing. + + Args: + tensor: Input tensor of shape [batch, heads, seq_q, seq_k] + br: Block row size + bc: Block column size + + Returns: + Tuple of (blocked_tensor, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k) + """ + batch_size, num_heads, seq_q, seq_k = tensor.shape + + # Calculate padding needed + padded_seq_q = math.ceil(seq_q / br) * br + padded_seq_k = math.ceil(seq_k / bc) * bc + + # Pad tensor if necessary + if padded_seq_q != seq_q or padded_seq_k != seq_k: + pad_q = padded_seq_q - seq_q + pad_k = padded_seq_k - seq_k + # Use dtype min instead of -inf for numerical stability + pad_value = torch.finfo(tensor.dtype).min + tensor = torch.nn.functional.pad(tensor, (0, pad_k, 0, pad_q), value=pad_value) + + # Reshape to blocks + num_block_rows = padded_seq_q // br + num_block_cols = padded_seq_k // bc + + # Keep natural order for row-level processing: [batch, heads, block_rows, br, block_cols, bc] + blocked = tensor.view(batch_size, num_heads, num_block_rows, br, num_block_cols, bc) + + return blocked, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k + + def calc_correction_factor_and_p( + self, attn_weights: torch.Tensor, phase: str + ) -> tuple[torch.Tensor, dict]: + """Calculate sparse mask and statistics for Flash Attention. + + Implements block-wise sparsity compatible with Flash Attention's online softmax: + 1. Reshape attention scores into 128x128 blocks + 2. Track block-wise maximum values (simulating Flash Attention's row processing) + 3. Compute cumulative maximum across blocks (for online normalization) + 4. Apply threshold: mask blocks where p = score - cummax < log(threshold) + 5. Calculate correction factor and sparsity statistics + + Args: + attn_weights: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + phase: "prefill" (seq_q > 1) or "decode" (seq_q = 1) + + Returns: + element_mask: Boolean mask [batch, heads, seq_q, seq_k] + stats: Dict with sparsity, correction_factor, total_blocks, etc. 
+ """ + batch_size, num_heads, seq_q, seq_k = attn_weights.shape + + # Calculate threshold + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + if threshold_scale_factor: + # Use calibrated dynamic threshold: λ = scale_factor / length + log_threshold = np.log(threshold_scale_factor / seq_k) + else: + # Use static threshold from config + log_threshold = np.log(self.threshold) + + if phase == "prefill": + blocked_attn, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k = ( + self._reshape_to_blocks(attn_weights, self.br, self.bc) + ) + + # Step 1: Compute maximum value in each block + # For each 128x128 block, find max across the 128 columns + # blocked_attn: [batch, heads, block_rows, br=128, block_cols, bc=128] + # block_max: [batch, heads, block_rows, br=128, block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across blocks (left to right) + # This simulates Flash Attention's online softmax normalization + # block_max_cummax: [batch, heads, block_rows, br=128, block_cols] + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor (how often max changes) + # Used by Flash Attention to adjust running sum when max increases + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize attention scores by cumulative max + # p represents log-space difference: log(score) - log(cummax) + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block-level mask + # Keep blocks where at least one element exceeds log(threshold) + p_larger_than_thresh = p > log_threshold + # Reduce over bc (128 cols), then br (128 rows) to get block-level decision + # Result: [batch, heads, block_rows, block_cols] + block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2) + + # Step 6: Expand block mask back to element level + # All 128x128 elements in a block share the same mask value + # [batch, heads, block_rows, block_cols] -> [batch, heads, block_rows, br=128, block_cols, bc=128] + element_mask = block_mask.unsqueeze(-2).unsqueeze(-1).expand_as(blocked_attn) + + # Step 7: Reshape to original attention shape and remove padding + element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 8: Calculate sparsity statistics + # Count kept blocks (averaged across batch and heads) + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + + # Total valid blocks (lower triangle only for causal attention) + # Note: Causal mask pre-applied by attention module, so block_mask naturally + # has zeros in upper triangle. We only count lower triangle for denominator. 
+ total_blocks = ( + num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 + if self.is_causal + else num_block_rows * num_block_cols # Non-causal: N*N + ) + sparsity = 1 - (kept_blocks / total_blocks) + else: # decode + blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( + attn_weights, 1, self.bc + ) + + # Decode: Single query row attends to all past key blocks + # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc=128] + + # Step 1: Find maximum in each key block + # block_max: [batch, heads, 1, 1, num_block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across key blocks (left to right) + # Simulates Flash Attention's online softmax normalization + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor + # Tracks how often the maximum increases (needed for Flash Attention rescaling) + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize scores by cumulative max + # p = log(score) - log(cummax) in log-space + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block mask + # Keep blocks where at least one element exceeds threshold + p_larger_than_thresh = p > log_threshold + block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False) + + # Step 6: Expand to element level and remove padding + element_mask = block_mask[..., None].expand_as(blocked_attn) + element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 7: Calculate statistics + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + total_blocks = num_block_cols + sparsity = 1 - (kept_blocks / total_blocks) + + # Create stats dictionary + stats = { + "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "sparsity": sparsity, + "phase": phase, + "total_blocks": total_blocks, + "sparse_blocks": int(sparsity * total_blocks), + "sample_length": seq_k, + } + + return element_mask, stats + + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply Flash Attention-aware block-wise sparsity. 
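+
+        Masked positions are filled with the dtype minimum so a subsequent softmax assigns them
+        (numerically) zero probability; query, key and value are returned unchanged.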
+ + Args: + query: Query tensor (unused, for API compatibility) + key: Key tensor (unused, for API compatibility) + value: Value tensor (unused, for API compatibility) + attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k] + + Returns: + Tuple with potentially modified attention_scores + """ + # Attention scores must be provided for sparse attention + assert attention_scores is not None, "attention_scores must be provided for apply_sparsity" + + # Attention scores are always 4D: [batch, heads, seq_q, seq_k] + assert len(attention_scores.shape) == 4, ( + f"Expected 4D attention scores, got shape {attention_scores.shape}" + ) + + # Infer phase from tensor shape + phase = self._infer_phase(attention_scores) + + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) + + # Apply block-wise sparsity + sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) + + # Store stats for module to collect (doesn't persist across calls) + self._last_stats = stats + + # Apply mask to create sparse scores + mask_value = torch.finfo(attention_scores.dtype).min + sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value) + + return query, key, value, sparse_scores + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for this method. + + Returns: + Dictionary with threshold configuration and calibration info. + """ + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + + if threshold_scale_factor is not None: + # Calibrated dynamic threshold + return { + "type": "dynamic", + "scale_factor": threshold_scale_factor, + "formula": "λ / length", + "example_lengths": { + 1024: threshold_scale_factor / 1024, + 2048: threshold_scale_factor / 2048, + 4096: threshold_scale_factor / 4096, + 8192: threshold_scale_factor / 8192, + }, + } + elif isinstance(self.threshold_config, dict): + # Phase-specific static thresholds + return { + "type": "static_phased", + "thresholds": self.threshold_config.copy(), + "current": self.threshold, + } + else: + # Single static threshold + return { + "type": "static", + "value": self.threshold, + } + + @property + def name(self) -> str: + """Method identifier.""" + return "flash_skip_softmax" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py new file mode 100644 index 000000000..b34b5b887 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Registry and base class for sparse attention methods.""" + +from abc import ABC, abstractmethod +from typing import Any + +import torch + + +class SparseAttentionMethod(ABC): + """Base class for sparse attention methods.""" + + @abstractmethod + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Apply sparsity to attention computation. + + Args: + query: Query tensor + key: Key tensor + value: Value tensor + attention_scores: Pre-computed attention scores + + Returns: + Tuple of (query, key, value, attention_scores) with sparsity applied + """ + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information for display/debugging. + + Returns: + Dictionary with threshold information. Should include: + - 'type': 'static', 'dynamic', or 'none' + - 'value': threshold value (for static) + - 'scale_factor': scale factor (for dynamic) + - Other method-specific info + """ + return {"type": "none", "value": None} + + @property + @abstractmethod + def name(self) -> str: + """Method name identifier.""" + + +# Method Registry with versioning support +_SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} + + +def register_sparse_method(name: str, version: str = "v1"): + """Decorator to register sparse attention methods with version support. + + Args: + name: Method name to register + version: Version string (default: "v1") + + Example:: + + @register_sparse_method("my_method", version="v3") + class MyMethodV3(SparseAttentionMethod): ... + """ + + def decorator(cls: type[SparseAttentionMethod]): + if name not in _SPARSE_ATTENTION_METHODS: + _SPARSE_ATTENTION_METHODS[name] = {} + + if version in _SPARSE_ATTENTION_METHODS[name]: + import warnings + + warnings.warn( + f"Overriding existing sparse attention method: {name}@{version}", + RuntimeWarning, + stacklevel=2, + ) + + _SPARSE_ATTENTION_METHODS[name][version] = cls + return cls + + return decorator + + +def get_sparse_method(name: str, version: str | None = None) -> type[SparseAttentionMethod]: + """Get sparse attention method by name and optional version. + + Args: + name: Method name to retrieve + version: Optional version string. If None, uses latest version. + + Returns: + Method class + + Raises: + ValueError: If method name or version is not registered + + Example: + >>> get_sparse_method("flash_skip_softmax") # Latest version + >>> get_sparse_method("flash_skip_softmax", "v1") # Specific version + """ + if name not in _SPARSE_ATTENTION_METHODS: + available = list(_SPARSE_ATTENTION_METHODS.keys()) + raise ValueError(f"Unknown sparse attention method: {name}. Available: {available}") + + method_versions = _SPARSE_ATTENTION_METHODS[name] + + if not version: + version = sorted(method_versions.keys())[-1] + + if version not in method_versions: + available_versions = list(method_versions.keys()) + raise ValueError( + f"Unknown version {version} for method {name}. Available: {available_versions}" + ) + + return method_versions[version] diff --git a/modelopt/torch/sparsity/attention_sparsity/mode.py b/modelopt/torch/sparsity/attention_sparsity/mode.py new file mode 100644 index 000000000..f389509a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/mode.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention mode descriptor for ModelOpt.""" + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) + +from .config import SparseAttentionConfig +from .conversion import ( + convert_to_sparse_attention_model, + restore_sparse_attention_model, + update_sparse_attention_metadata, +) + +# Create registry for sparse attention modes +SparseAttentionModeRegistry = _ModeRegistryCls("sparse_attention") + + +@SparseAttentionModeRegistry.register_mode +class SparseAttentionModeDescriptor(ModeDescriptor): + """Mode descriptor for sparse attention optimization. + + This mode enables various sparse attention methods to reduce + computational complexity and memory usage in transformer models. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "sparse_attention" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return SparseAttentionConfig + + @property + def next_prohibited_modes(self) -> set[str] | None: + """Modes that should not be applied after this mode.""" + # Can work with quantization but not with weight sparsity + return {"sparsity"} + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode of this mode.""" + return "export_sparse_attention" + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_sparse_attention_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_sparse_attention_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before saving.""" + return update_sparse_attention_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before new mode.""" + return update_sparse_attention_metadata diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py new file mode 100644 index 000000000..b6b1e809f --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main API functions for sparse attention optimization."""
+
+from typing import Any
+
+import torch
+
+from modelopt.torch.opt.conversion import apply_mode
+from modelopt.torch.opt.searcher import ForwardLoop
+
+from .calibration import calibrate_sparse_attention
+from .config import SparseAttentionConfig
+from .mode import SparseAttentionModeRegistry
+
+__all__ = [
+ "calibrate",
+ "sparsify",
+]
+
+
+def sparsify(
+ model: torch.nn.Module,
+ config: dict[str, Any] | SparseAttentionConfig,
+ forward_loop: ForwardLoop | None = None,
+) -> torch.nn.Module:
+ """Applies sparse attention optimization to the model in-place.
+
+ This method performs replacement of attention modules with their sparse counterparts.
+
+ Args:
+ model: A pytorch model
+ config: A dictionary or an instance of
+ :class:`SparseAttentionConfig`
+ specifying the values for keys ``"sparse_cfg"`` and ``"method"``.
+
+ The ``"sparse_cfg"`` key specifies the sparse attention configurations.
+ The ``"method"`` key specifies the sparse attention method (e.g., "flash_skip_softmax").
+
+ The sparse attention configuration is a dictionary mapping wildcards or filter functions
+ to sparse attention attributes. The wildcards or filter functions are matched
+ against the module names. The sparse attention attributes include ``"threshold"``,
+ ``"enable"``, and method-specific parameters.
+
+ An example ``config`` dictionary is given below:
+
+ .. code-block:: python
+
+ config = {
+ "sparse_cfg": {
+ # Phase-aware thresholds with backend selection
+ "*attention*": {
+ "method": "flash_skip_softmax",
+ "threshold": {"prefill": 1e-3, "decode": 1e-5},
+ "backend": "pytorch", # Only pytorch backend supported
+ "enable": True,
+ },
+ # Disable for specific layers
+ "*layer.0*": {"enable": False},
+ # Default settings
+ "default": {"enable": False},
+ },
+ }
+
+ For automatic threshold calibration using RULER dataset:
+
+ .. code-block:: python
+
+ config = {
+ "sparse_cfg": {
+ "*attention*": {
+ "method": "flash_skip_softmax",
+ "backend": "pytorch",
+ "enable": True,
+ "calibration": { # Enables automatic threshold calibration
+ "target_sparse_ratio": 0.5,
+ "samples": 48,
+ "max_seqlen": 8192,
+ },
+ },
+ "default": {"enable": False},
+ },
+ }
+
+ The ``"backend"`` parameter must be set to ``"pytorch"``:
+
+ - ``"pytorch"``: Softmax patching approach (only supported backend)
+
+ This requires the model to be loaded with ``attn_implementation="eager"``.
+
+ forward_loop: Optional callable that forwards calibration data through the model;
+ it is passed through to ``calibrate``.
+
+ Here are a few examples for correct ``forward_loop`` definitions:
+
+ Example 1:
+
+ .. code-block::
+
+ def forward_loop(model) -> None:
+ # iterate over the data loader and forward data through the model
+ for batch in data_loader:
+ model(batch)
+
+ Example 2:
+
+ .. code-block::
+
+ def forward_loop(model) -> float:
+ # evaluate the model on the task
+ return evaluate(model, task, ....)
+
+ .. note::
+
+ Calibration does not require forwarding the entire dataset through the model.
+ Please subsample the dataset or reduce the number of batches if needed.
+
+ .. important::
+
+ The model must always be loaded with ``attn_implementation="eager"``
+ for sparse attention to work correctly:
+
+ .. code-block:: python
+
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ attn_implementation="eager", # Required for sparse attention
+ torch_dtype=torch.bfloat16,
+ )
+
+ This is because sparse attention works by patching torch.nn.functional.softmax,
+ which is only called in the eager attention implementation.
+
+ Returns:
+ A pytorch model which has sparse attention applied and optionally calibrated.
+ """
+ model = apply_mode(
+ model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry
+ )
+
+ # Calibrate the sparsity ratio of the attention modules
+ return calibrate(model, config, forward_loop=forward_loop)
+
+
+def calibrate(
+ model: torch.nn.Module,
+ config: dict[str, Any] | SparseAttentionConfig,
+ forward_loop: ForwardLoop | None = None,
+) -> torch.nn.Module:
+ """Calibrates sparse attention thresholds based on target sparsity.
+
+ Args:
+ model: Model with sparse attention modules
+ config: Sparse attention configuration with calibration settings
+ forward_loop: Optional callable that forwards calibration data through the model.
+ If provided, uses this for calibration data.
+ If None, will auto-generate RULER dataset for calibration.
+
+ Returns:
+ The calibrated model with optimized sparse attention thresholds.
+ """
+ calibrate_sparse_attention(model, config, forward_loop=forward_loop)
+ return model
diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
new file mode 100644
index 000000000..ba8c8b821
--- /dev/null
+++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Plugins for sparse attention integration with various frameworks."""
+
+from .huggingface import register_sparse_attention_on_the_fly
+
+__all__ = [
+ "register_sparse_attention_on_the_fly",
+]
diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
new file mode 100644
index 000000000..0c4a8baf9
--- /dev/null
+++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic sparse attention registration for HuggingFace models.""" + +import torch.nn as nn +import transformers + +from modelopt.torch.opt.dynamic import DynamicModule + +from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry + + +class _GenericSparseAttention(SparseAttentionModule): + """Generic sparse attention that works with any HF attention module. + + This class provides a universal sparse attention wrapper that can + work with various transformer attention implementations. + """ + + def _setup(self): + """Setup sparse attention for any attention type. + + The base SparseAttentionModule handles detection and initialization. + """ + super()._setup() + + def get_attn_type(self, attn_module) -> type: + """Get the original attention type. + + Args: + attn_module: Attention module (possibly wrapped) + + Returns: + Original class type + """ + # If this is a DynamicModule, get the original class + if isinstance(attn_module, DynamicModule): + return attn_module.get_original_cls_by_level(level=0) + return type(attn_module) + + +def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: + """Dynamically register sparse attention for any model. + + This function automatically detects attention modules in the model + and registers them for sparse attention optimization. + + Args: + model: Model to process + + Returns: + True if any modules were registered + """ + if not _is_supported_model(model): + return False + + registered_count = 0 + attention_types = set() + + for name, module in model.named_modules(): + # Skip if already a sparse attention module + if isinstance(module, SparseAttentionModule): + continue + + # Check if this is an attention module by name + module_type = type(module) + type_name = module_type.__name__ + + # Common attention module patterns + is_attention = "attention" in type_name.lower() or type_name.endswith( + ("Attention", "SelfAttention") + ) + + if is_attention and module_type not in SparseAttentionRegistry: + # Register attention type + if module_type not in attention_types: + SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) + attention_types.add(module_type) + registered_count += 1 + print(f"Registered {type_name} for sparse attention optimization") + + if registered_count > 0: + print(f"Dynamically registered {registered_count} attention module types for sparsity") + + return registered_count > 0 + + +def _is_supported_model(model: nn.Module) -> bool: + """Check if model is supported for sparse attention. + + Supports HuggingFace PreTrainedModel and any PyTorch model with attention modules. 
+ + Args: + model: Model to check + + Returns: + True if model is supported + """ + # Check for HuggingFace PreTrainedModel + try: + if isinstance(model, transformers.PreTrainedModel): + return True + except ImportError: + pass + + # Support any PyTorch model with attention modules + return isinstance(model, nn.Module) diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py new file mode 100644 index 000000000..d31a9e882 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensible sparse attention module.""" + +from typing import Any + +import torch +import torch.nn.functional as F + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from modelopt.torch.quantization.utils import replace_function + +from .config import SparseAttentionAttributeConfig +from .methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager + + +class SparseAttentionModule(DynamicModule): + """Generic sparse attention module wrapper for applying sparsity to attention layers. + + This module wraps existing attention implementations to add sparse attention + capabilities by patching torch.nn.functional.softmax. + + Forward Flow: + ------------- + 1. Check if sparse attention is enabled (pass-through if disabled) + 2. Create softmax patch context with sparse_softmax function + 3. Apply sparse attention by patching F.softmax: + - Patches torch.nn.functional.softmax with sparse_softmax + - sparse_softmax applies method's sparsity logic before softmax + 4. Forward through original attention with sparsity applied + + Requirements: + ------------- + - Model must be loaded with attn_implementation="eager" for proper softmax interception + - Only PyTorch backend is supported (patches F.softmax) + + Attributes: + ----------- + _enabled: bool + Whether sparse attention is enabled + _method: str + The sparse attention method to use (e.g., "flash_skip_softmax") + _method_config: dict + Configuration dictionary for the sparse method (threshold, br, bc, etc.) + _sparse_method_instance: SparseAttentionMethod + Instance of the configured sparse attention method + """ + + def set_from_attribute_config( + self, attribute_cfg: SparseAttentionAttributeConfig | dict | None = None + ): + """Set sparse attention attributes from configuration. + + Similar to TensorQuantizer.set_from_attribute_config. + + Args: + attribute_cfg: Sparse attention attribute configuration. 
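+
+ Example (illustrative values; ``module`` is a ``SparseAttentionModule`` instance
+ and the dict is validated through ``SparseAttentionAttributeConfig``)::
+
+ module.set_from_attribute_config(
+ {"method": "flash_skip_softmax", "threshold": 1e-4, "enable": True}
+ )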
+ """ + # Ensure config is validated through Pydantic + if not isinstance(attribute_cfg, SparseAttentionAttributeConfig): + attribute_cfg = SparseAttentionAttributeConfig(**(attribute_cfg or {})) + + # Store raw config for method initialization + self._method_config = {} + + # Define which attributes are method-specific vs module-specific + # Module-specific attributes control the SparseAttentionModule behavior + _module_attributes = {"enable", "method"} + + # Custom setters for special module attributes + _custom_setters = { + "enable": ("_enabled", lambda val: bool(val)), + "method": ("_method", lambda val: str(val)), + } + + # Process each attribute from validated config + for attribute, val in attribute_cfg.model_dump().items(): + # Validate attribute if using config class + if hasattr(SparseAttentionAttributeConfig, "model_fields"): + assert attribute in SparseAttentionAttributeConfig.model_fields, ( + f"{attribute} is not a valid SparseAttentionModule attribute" + ) + + if attribute in _module_attributes: + # Module-level attribute: store with underscore prefix + attr_name, setter = _custom_setters.get(attribute, (f"_{attribute}", lambda v: v)) + setattr(self, attr_name, setter(val)) + else: + # Method-specific attribute: store in config dict + self._method_config[attribute] = val + + # Initialize sparse method instance + self._init_sparse_method() + + # Create stats manager based on config + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + + def _init_sparse_method(self): + """Initialize the sparse method instance.""" + method_class = get_sparse_method(self._method) + + # Initialize the sparse method instance + # _method_config is always initialized in set_from_attribute_config + self._sparse_method_instance = method_class(method_config=self._method_config) # type: ignore[call-arg] + + def enable(self): + """Enable sparse attention for this module.""" + self._enabled = True + + def disable(self): + """Disable sparse attention for this module.""" + self._enabled = False + + @property + def is_enabled(self) -> bool: + """Check if sparse attention is enabled.""" + return getattr(self, "_enabled", True) + + def get_stats(self) -> dict: + """Get sparsity statistics from the stats manager. + + Returns: + Dictionary with sparsity statistics including 'average_sparsity' if available. + Returns empty dict if stats manager is not enabled. + """ + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() + return {} + + def get_threshold_info(self) -> dict[str, Any]: + """Get threshold information from the sparse method instance. + + Returns: + Dictionary with threshold information from the sparse method. + """ + if hasattr(self, "_sparse_method_instance") and self._sparse_method_instance is not None: + return self._sparse_method_instance.get_threshold_info() + return {"type": "none", "value": None} + + def _setup(self): + """Setup called by DynamicModule.""" + # Apply default configuration if not yet configured + if not hasattr(self, "_method"): + self.set_from_attribute_config(None) + + def forward(self, *args, **kwargs): + """Forward with selected sparse attention method. + + This method dispatches to the appropriate sparse attention implementation + based on the configured method and backend. 
+ """ + # Pass through if sparse attention is disabled + if not self.is_enabled: + return super().forward(*args, **kwargs) + + # Get the appropriate context manager for this configuration + context = self._get_sparse_context() + + # Apply sparse attention through the context + with context: + result = super().forward(*args, **kwargs) + + # Collect stats if manager is available + if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): + self._stats_manager.collect(self._sparse_method_instance._last_stats) + + return result + + def _get_sparse_context(self): + """Get the softmax patch context for applying sparse attention.""" + return self._create_softmax_patch_context() + + def _create_softmax_patch_context(self): + """Create context manager for patching softmax function.""" + return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) + + def _create_sparse_softmax(self): + """Create sparse softmax function for current method.""" + original_softmax = F.softmax + + def sparse_softmax(input, dim=-1, *args, **kwargs): + # Let the method handle the sparsification + _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( + None, None, None, input + ) + + # Use sparse input if modified, otherwise use original + if sparse_input is not None: + return original_softmax(sparse_input, dim, *args, **kwargs) + return original_softmax(input, dim, *args, **kwargs) + + return sparse_softmax + + +# Create registry for sparse attention modules +SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule) diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py new file mode 100644 index 000000000..9fc57a0b1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Statistics manager for sparse attention modules.""" + + +class SparseAttentionStatsManager: + """Centralized statistics manager for sparse attention. + + This class is the single source of truth for all statistics collection + in sparse attention modules. It handles both runtime aggregation and + per-sample calibration statistics. + + Design principles: + - Single responsibility: only stats management + - No computation: receives pre-computed stats from methods + - Optional: can be None if stats collection disabled + - Zero overhead when disabled + """ + + def __init__(self, module_name: str, enabled: bool = True): + """Initialize stats manager. 
+ + Args: + module_name: Name of the module this manager is attached to + enabled: Whether stats collection is enabled + """ + self.module_name = module_name + self.enabled = enabled + self.calibration_mode = False + + # Aggregated stats (running totals across all forward passes) + self.aggregated_stats: dict = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + + # Per-sample stats (only populated during calibration) + self.per_sample_stats: list[dict] = [] + + def collect(self, stats: dict): + """Collect statistics from a single forward pass. + + Args: + stats: Dictionary containing statistics from method computation. + Expected keys: sparsity, phase, total_blocks, sparse_blocks, + sample_length (optional) + """ + if not self.enabled: + return + + # Update aggregated stats + self.aggregated_stats["total_calls"] += 1 + self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) + self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + phase = stats.get("phase", "unknown") + if phase in self.aggregated_stats["phase_counts"]: + self.aggregated_stats["phase_counts"][phase] += 1 + + # In calibration mode, store per-sample stats + if self.calibration_mode: + self.per_sample_stats.append( + { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + ) + + def get_summary(self) -> dict: + """Get aggregated statistics summary. + + Returns: + Dictionary with module name, total calls, average sparsity, + and phase distribution. + """ + total_blocks = self.aggregated_stats["total_blocks"] + if total_blocks > 0: + avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + else: + avg_sparsity = 0.0 + + return { + "module": self.module_name, + "total_calls": self.aggregated_stats["total_calls"], + "average_sparsity": avg_sparsity, + "phase_distribution": self.aggregated_stats["phase_counts"].copy(), + } + + def set_calibration_mode(self, enabled: bool, reset_history: bool = True): + """Enable or disable calibration mode. + + In calibration mode, per-sample statistics are stored for detailed + analysis. Otherwise, only aggregated stats are maintained. + + Args: + enabled: Whether to enable calibration mode + reset_history: Whether to clear per_sample_stats when enabling + """ + self.calibration_mode = enabled + if enabled and reset_history: + self.per_sample_stats = [] + + def reset(self): + """Reset all statistics to initial state.""" + self.aggregated_stats = { + "total_calls": 0, + "total_blocks": 0, + "sparse_blocks": 0, + "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, + } + self.per_sample_stats = [] + + def get_calibration_stats(self) -> list[dict]: + """Get per-sample calibration statistics. + + Returns: + List of per-sample statistics dictionaries. + Empty list if not in calibration mode. 
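+
+ Example (illustrative values; ``manager`` is a ``SparseAttentionStatsManager`` instance)::
+
+ manager.set_calibration_mode(True)
+ manager.collect(
+ {"sparsity": 0.6, "phase": "prefill", "total_blocks": 10,
+ "sparse_blocks": 6, "sample_length": 1024}
+ )
+ manager.get_calibration_stats()
+ # [{"module": ..., "sparsity": 0.6, "sample_length": 1024, "phase": "prefill"}]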
+ """ + return self.per_sample_stats diff --git a/setup.py b/setup.py index 158c91e40..3ccd773b3 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,8 @@ "torch-geometric", "tox>4.18", "tox-current-env>=0.0.12", + "nltk", + "wonderwords", ], # docs "dev-docs": [ diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py new file mode 100644 index 000000000..5ed079966 --- /dev/null +++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utilities for sparse attention testing.""" + +import torch +import torch.nn as nn + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +# Test models for sparse attention +class SimpleAttentionModel(nn.Module): + """Simple attention model for testing.""" + + def __init__(self, hidden_size=256, num_heads=8): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.attention = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True + ) + self.fc = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x, need_weights=False) + return self.fc(attn_output) + + @classmethod + def get_input(cls, hidden_size=256, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, hidden_size) + + +class SimpleTransformerEncoderLayer(nn.Module): + """Simple TransformerEncoderLayer wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, dim_feedforward=256): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + + def forward(self, x): + return self.layer(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=20, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +class SimpleTransformerEncoder(nn.Module): + """Simple TransformerEncoder wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, num_layers=2): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True), + num_layers=num_layers, + ) + + def forward(self, x): + return self.encoder(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +# Test configurations +FLASH_SKIP_SOFTMAX_DEFAULT_CFG = { + "sparse_cfg": { + "*attention*": { + "method": 
"flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + +FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + + +def get_test_configs(): + """Get test configurations for parameterized tests. + + Note: Calibration config excluded (requires GPU and real tokenizers). + """ + return [FLASH_SKIP_SOFTMAX_DEFAULT_CFG, FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG] + + +def sparsify_model_and_forward(model, config, calib_data): + """Apply sparse attention and run forward passes. + + Args: + model: Model to sparsify + config: Sparse attention configuration + calib_data: List of calibration data tensors + + Returns: + Sparsified model + """ + + def forward_loop(model): + for batch in calib_data: + model(batch) + + # Apply sparse attention + model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules were inserted + assert any(isinstance(m, SparseAttentionModule) for m in model.modules()), ( + "No sparse attention modules found" + ) + + # Test forward passes + model.eval() + with torch.no_grad(): + for batch in calib_data: + output = model(batch) + assert not torch.isnan(output).any(), ( + f"NaN detected in output for batch shape {batch.shape}" + ) + assert output is not None, f"Output is None for batch shape {batch.shape}" + + return model + + +def save_restore_test(model_cls, device, sparse_config, atol=1e-6): + """Test save and restore of sparse attention state. + + Args: + model_cls: Model class to test + device: Device to run on ('cpu' or 'cuda') + sparse_config: Sparse attention configuration + """ + # Create and sparsify reference model + model_sparse = model_cls().to(device) + calib_data = [model_sparse.get_input().to(device) for _ in range(2)] + + sparsify_model_and_forward(model_sparse, sparse_config, calib_data) + + # Save state + state_dict = mto.modelopt_state(model_sparse) + + # Restore to new model + model_restored = model_cls().to(device) + mto.restore_from_modelopt_state(model_restored, state_dict) + model_restored.load_state_dict(model_sparse.state_dict()) + + # Verify outputs match + test_input = calib_data[0] + model_sparse.eval() + model_restored.eval() + + with torch.no_grad(): + output_sparse = model_sparse(test_input) + output_restored = model_restored(test_input) + + assert torch.allclose(output_sparse, output_restored, atol), ( + "Restored model output doesn't match original" + ) diff --git a/tests/examples/llm_eval/test_llm_eval.py b/tests/examples/llm_eval/test_llm_eval.py index 0abf78b53..88d29dedc 100644 --- a/tests/examples/llm_eval/test_llm_eval.py +++ b/tests/examples/llm_eval/test_llm_eval.py @@ -36,3 +36,20 @@ def test_llama_eval_fp8(): finally: # Force kill llm-serve if it's still running subprocess.run(["pkill", "-f", "llm-serve"], check=False) + + +def test_llama_eval_sparse_attention(tiny_llama_path): + """Test sparse attention with llm_eval integration.""" + try: + # Test with default sparse attention config (no quantization) + run_llm_ptq_command( + model=tiny_llama_path, + quant="none", # No quantization, only sparse attention + tasks="lm_eval", + lm_eval_tasks="hellaswag", + lm_eval_limit=0.05, # Small limit for fast test + sparse_cfg="SKIP_SOFTMAX_DEFAULT", + batch=4, + ) + finally: + subprocess.run(["pkill", "-f", "llm-serve"], check=False) diff --git 
a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py new file mode 100644 index 000000000..161d5d1df --- /dev/null +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test attention sparsity example script.""" + +import pytest +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.misc import minimum_gpu + + +def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs): + """Run attention sparsity example script. + + Args: + model: Path to model + method: Sparse attention method (corresponds to --sparse_attn arg) + **kwargs: Additional arguments to pass to the script + """ + kwargs.update( + { + "pyt_ckpt_path": model, + "sparse_attn": method, + } + ) + kwargs.setdefault("seq_len", 128) + kwargs.setdefault("num_samples", 1) + kwargs.setdefault("max_new_tokens", 16) + + cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs) + run_example_command(cmd_parts, "llm_sparsity/attention_sparsity") + + +@minimum_gpu(1) +@pytest.mark.parametrize("method", ["skip_softmax", "skip_softmax_calib"]) +def test_attention_sparsity(tiny_llama_path, tmp_path, method): + """Test sparse attention with TinyLlama (with and without calibration).""" + run_attention_sparsity_command( + model=tiny_llama_path, + method=method, + seq_len=128, + num_samples=1, + max_new_tokens=10, + ) diff --git a/tests/examples/llm_sparsity/test_llama_sparsify.py b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py similarity index 93% rename from tests/examples/llm_sparsity/test_llama_sparsify.py rename to tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py index 7f9ef929b..7094b2989 100644 --- a/tests/examples/llm_sparsity/test_llama_sparsify.py +++ b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py @@ -31,7 +31,7 @@ def run_llm_sparsity_command( kwargs.setdefault("model_max_length", 1024) cmd_parts = extend_cmd_parts(["python", "hf_pts.py"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") def run_llm_sparsity_ft_command( @@ -51,13 +51,15 @@ def run_llm_sparsity_ft_command( kwargs.setdefault("eval_bs", 1) cmd_parts = extend_cmd_parts(["bash", "launch_finetune.sh"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") @pytest.fixture(scope="session") def data_path(tmp_path_factory): data_path = tmp_path_factory.mktemp("data") - run_example_command(["python", "data_prep.py", "--save_path", data_path], "llm_sparsity") + run_example_command( + ["python", "data_prep.py", "--save_path", data_path], "llm_sparsity/weight_sparsity" + ) 
# Copy eval data to train path for faster test run_example_command( diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py new file mode 100644 index 000000000..bad077fdb --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for attention sparsity module.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoder, + SimpleTransformerEncoderLayer, + get_test_configs, + save_restore_test, + sparsify_model_and_forward, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +class TestAttentionSparsityGPU: + """GPU tests for attention sparsity.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup for each test.""" + self.device = torch.device("cuda") + torch.cuda.empty_cache() + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + @pytest.mark.parametrize("config", get_test_configs()) + def test_gpu_forward(self, model_cls, config): + """Test sparse attention forward pass on GPU.""" + model = model_cls().to(self.device) + calib_data = [model.get_input().to(self.device) for _ in range(2)] + + sparsify_model_and_forward(model, config, calib_data) + + # Additional GPU-specific checks + for batch in calib_data: + with torch.no_grad(): + output = model(batch) + assert output.device.type == "cuda" + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + def test_save_restore(self, model_cls): + """Test save and restore on GPU.""" + save_restore_test(model_cls, "cuda", FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_different_dtypes(self, dtype): + """Test sparse attention with different dtypes.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device).to(dtype) + calib_data = [model.get_input(d_model=256).to(self.device).to(dtype) for _ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG, calib_data) + + # Test forward + x = model.get_input(d_model=256).to(self.device).to(dtype) + with torch.no_grad(): + output = sparse_model(x) + + assert output.dtype == dtype + assert not torch.isnan(output).any() + if dtype != torch.bfloat16: # bfloat16 can have inf + assert not torch.isinf(output).any() + + def test_backward_pass(self): + """Test that gradients 
flow correctly through sparse attention.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Enable training mode + model.train() + + x = model.get_input(hidden_size=128, seq_len=32).to(self.device) + x.requires_grad = True + + # Forward + output = model(x) + loss = output.sum() + + # Backward + loss.backward() + + # Check gradients exist + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + # Check model gradients + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + @pytest.mark.parametrize("seq_len", [1, 1024, 2048]) + def test_various_sequence_lengths(self, seq_len): + """Test sparse attention with various sequence lengths.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(hidden_size=128, seq_len=seq_len, batch_size=1).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (1, seq_len, 128) + assert not torch.isnan(output).any() + + @pytest.mark.parametrize("batch_size", [1, 8, 16]) + def test_various_batch_sizes(self, batch_size): + """Test sparse attention with various batch sizes.""" + model = SimpleTransformerEncoderLayer(d_model=128, nhead=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(d_model=128, seq_len=64, batch_size=batch_size).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (batch_size, 64, 128) + assert not torch.isnan(output).any() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py new file mode 100644 index 000000000..913dc24a0 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -0,0 +1,388 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPU tests for sparse attention calibration.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import SimpleTransformerEncoderLayer + +import modelopt.torch.opt as mto +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import RulerDatasetBuilder +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if no GPU available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") + + +class TestRulerDatasetBuilderGPU: + """Test RULER dataset generation with real tokenizers on GPU.""" + + def test_ruler_generation_with_real_tokenizer(self): + """Test RULER generation with GPT2 tokenizer.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 samples (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 6 samples (1 per task) + assert len(dataset) == 6 + + # All samples should have valid structure + for sample in dataset: + assert "input" in sample + assert "length" in sample + assert sample["length"] > 0 + + def test_generated_length_accuracy(self): + """Test that generated token counts are accurate.""" + builder = RulerDatasetBuilder( + samples=3, + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check that lengths are within reasonable range of target + for sample in dataset: + # RULER aims for 70-90% of target for context + assert 700 < sample["length"] < 1400 + + def test_multiple_subtasks(self): + """Test generation with multiple RULER subtasks.""" + builder = RulerDatasetBuilder( + samples=12, # Need at least 6 for 1 per task, use 12 for 2 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Check task distribution (should have multiple tasks from RULER_TASKS) + tasks_found = {s["task"] for s in dataset} + assert len(tasks_found) >= 2 # At least 2 different tasks + + def test_large_context_lengths(self): + """Test with larger context lengths.""" + builder = RulerDatasetBuilder( + samples=24, # 4 lengths * 6 tasks = need 24 for 1 per (length, task) + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + assert len(dataset) == 24 + + # Verify we have different lengths + lengths = [s["length"] for s in dataset] + # Should have variety of lengths across the bins + assert len(set(lengths)) > 1 # At least 2 different target lengths used + + +class TestCalibrationGPU: + """Test calibration with real models on GPU.""" + + @pytest.fixture + def simple_model(self): + """Create simple attention model for testing.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibration_simple_model(self, simple_model): + """Test calibration with simple attention model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + # Simple forward loop for calibration + pass + 
+ # Apply sparse attention with calibration + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules exist + sparse_modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(sparse_modules) > 0 + + # Verify calibration was applied + for module in sparse_modules: + method = module._sparse_method_instance + # Check if calibrated threshold scale factor is set + if hasattr(method, "threshold_scale_factor") and method.threshold_scale_factor: + assert method.threshold_scale_factor > 0 + + def test_calibration_pytorch_backend(self, simple_model): + """Test calibration with pytorch backend.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Check backend is set correctly + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert hasattr(method, "backend") + assert method.backend == "pytorch" + + def test_simplified_calibration(self, simple_model): + """Test simplified calibration (prefill phase only).""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Should complete without errors + assert sparse_model is not None + + def test_calibration_persistence(self, simple_model): + """Test save and restore of calibrated model.""" + model = simple_model + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Save modelopt state + modelopt_state = mto.modelopt_state(sparse_model) + + # Create new model and restore + model_restored = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + + restored = mto.restore_from_modelopt_state(model_restored, modelopt_state) + + # Check that sparse attention is restored + has_sparse = any(isinstance(m, SparseAttentionModule) for m in restored.modules()) + assert has_sparse + + +class TestCalibrationEndToEnd: + """Integration tests with inference.""" + + @pytest.fixture + def simple_model_setup(self): + """Setup simple model.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda() + return model + + def test_calibrated_model_inference(self, simple_model_setup): + """Test inference with calibrated model.""" + model = simple_model_setup + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate model + sparse_model = sparsify(model, config, forward_loop=forward_loop) + + # Test inference + test_input = 
SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + sparse_model.eval() + with torch.no_grad(): + output = sparse_model(test_input) + + # Check output is valid + assert output is not None + assert not torch.isnan(output).any() + + def test_calibrated_vs_fixed_threshold(self, simple_model_setup): + """Compare calibrated vs fixed threshold models.""" + # Config with calibration + config_calibrated = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + # Config with fixed threshold (no calibration) + config_fixed = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + } + }, + } + + def forward_loop(model): + pass + + # Test both can be created + model_calibrated = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_calibrated, + forward_loop=forward_loop, + ) + + model_fixed = sparsify( + SimpleTransformerEncoderLayer(d_model=256, nhead=8).cuda(), + config_fixed, + ) + + # Both should work for inference + test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=10).cuda() + + with torch.no_grad(): + output_calibrated = model_calibrated(test_input) + output_fixed = model_fixed(test_input) + + assert output_calibrated is not None + assert output_fixed is not None + + def test_memory_usage(self, simple_model_setup): + """Test that calibration doesn't cause memory issues.""" + model = simple_model_setup + + # Clear cache before test + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 2, + "max_seqlen": 1024, + }, + } + }, + } + + def forward_loop(model): + pass + + # Calibrate + sparsify(model, config, forward_loop=forward_loop) + + # Check memory didn't explode + final_memory = torch.cuda.memory_allocated() + memory_increase = final_memory - initial_memory + + # Memory should be reasonable (not more than 2GB increase) + assert memory_increase < 2 * 1024**3 # 2GB + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py new file mode 100644 index 000000000..586cb3b9d --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration testing with locally created minimal Llama model.""" + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create minimal Llama model locally.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, # Minimal layers for fast testing + hidden_size=512, + intermediate_size=1024, + ) + + +@pytest.fixture(scope="module") +def tinyllama_model(tiny_llama_dir): + """Load locally created tiny Llama model.""" + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + device_map="cuda", + ) + return model + + +@pytest.fixture(scope="module") +def tinyllama_tokenizer(tiny_llama_dir): + """Load tokenizer for tiny Llama model.""" + tokenizer = AutoTokenizer.from_pretrained(tiny_llama_dir) + return tokenizer + + +class TestTinyLlama: + """TinyLlama sparse attention tests.""" + + def test_load_and_sparsify(self, tinyllama_model): + """Load TinyLlama and apply sparse attention.""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse attention modules were added + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + assert sparse_count > 0, "No sparse attention modules found" + + # Our tiny llama has 2 layers, so should have 2 attention modules + assert sparse_count == 2, f"Expected 2 sparse modules, got {sparse_count}" + + def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): + """Forward pass with seq_len=64 (prefill).""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create prefill input (seq_len > 1) + test_text = "Once upon a time in a land far away" + inputs = tokenizer(test_text, return_tensors="pt").to("cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(**inputs) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape[1] == inputs.input_ids.shape[1] # seq_len preserved + + def test_forward_decode(self, tinyllama_model): + """Forward pass with seq_len=1 (decode).""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-5, # More conservative for decode + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create decode input (seq_len = 1) + input_ids = torch.randint(0, 32000, (1, 1), device="cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + 
outputs = sparse_model(input_ids) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape == (1, 1, 32000) # batch=1, seq=1, vocab_size + + def test_gqa_attention(self, tinyllama_model): + """Verify GQA support (num_kv_heads < num_heads).""" + model = tinyllama_model + + # Check if model uses GQA + config = model.config + has_gqa = hasattr(config, "num_key_value_heads") and ( + config.num_key_value_heads < config.num_attention_heads + ) + + if not has_gqa: + pytest.skip("Model does not use GQA") + + # Apply sparse attention + sparse_config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, sparse_config) + + # Test forward pass with GQA + input_ids = torch.randint(0, 32000, (1, 32), device="cuda") + + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py new file mode 100644 index 000000000..b487d8639 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for FlashSkipSoftmax method internals.""" + +import pytest +import torch + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax + + +class TestFlashSkipSoftmaxMethod: + """Test FlashSkipSoftmax method internals.""" + + def test_phase_inference(self): + """Test phase detection from attention score shape.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Prefill: seq_q > 1 + prefill_scores = torch.randn(2, 4, 64, 64) + assert method._infer_phase(prefill_scores) == "prefill" + + # Decode: seq_q = 1 + decode_scores = torch.randn(2, 4, 1, 64) + assert method._infer_phase(decode_scores) == "decode" + + def test_threshold_update_dict_config(self): + """Test threshold updates with dict config.""" + method = FlashSkipSoftmax( + { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Initially uses prefill threshold + initial_threshold = method.threshold + + # Update to decode + method._update_threshold("decode") + assert method.threshold == 1e-5 + assert method.threshold != initial_threshold + + # Update back to prefill + method._update_threshold("prefill") + assert method.threshold == 1e-3 + + def test_threshold_update_static_config(self): + """Test threshold with static float config.""" + method = FlashSkipSoftmax( + { + "threshold": 5e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + initial_threshold = method.threshold + assert initial_threshold == 5e-4 + + # Should not change for static config + method._update_threshold("decode") + assert method.threshold == 5e-4 + + def test_block_reshaping_divisible(self): + """Test block reshaping with divisible sequence lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths divisible by 128 + attn = torch.randn(2, 4, 256, 256) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify block dimensions + assert blocked.shape == (2, 4, 2, 128, 2, 128) # 256/128 = 2 blocks + assert num_br == 2 + assert num_bc == 2 + assert padded_q == 256 # No padding + assert padded_k == 256 # No padding + + def test_block_reshaping_with_padding(self): + """Test block reshaping with non-divisible lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths NOT divisible by 128 + attn = torch.randn(2, 4, 200, 300) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify padding applied + assert padded_q == 256 # ceil(200/128) * 128 = 2 * 128 + assert padded_k == 384 # ceil(300/128) * 128 = 3 * 128 + assert num_br == 2 + assert num_bc == 3 + assert blocked.shape == (2, 4, 2, 128, 3, 128) + + def test_correction_factor_calculation_prefill(self): + """Test correction factor for prefill phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Create simple attention pattern + attn = torch.randn(1, 1, 128, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify stats structure + assert "correction_factor" in stats + assert "sparsity" in stats + assert 
"phase" in stats + assert "total_blocks" in stats + assert stats["phase"] == "prefill" + assert 0 <= stats["correction_factor"] <= 1 + # Sparsity can be negative if threshold is too low (more blocks kept than expected) + assert -1 <= stats["sparsity"] <= 1 + + def test_correction_factor_calculation_decode(self): + """Test correction factor for decode phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-5, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Decode: single query + attn = torch.randn(1, 1, 1, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "decode") + + # Verify stats structure + assert stats["phase"] == "decode" + assert "correction_factor" in stats + assert 0 <= stats["sparsity"] <= 1 + assert mask.shape == (1, 1, 1, 256) + + def test_sparsity_statistics(self): + """Test sparsity statistics structure.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(1, 1, 128, 256) + _, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify statistics are present + assert stats["total_blocks"] > 0 + assert "sparse_blocks" in stats + assert "sample_length" in stats + assert stats["sample_length"] == 256 + + def test_block_mask_correctness(self): + """Test block mask shape and type.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + mask, _ = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify mask properties + assert mask.shape == attn.shape + assert mask.dtype == torch.bool + assert mask.device == attn.device + + def test_causal_vs_noncausal(self): + """Test total_blocks calculation for causal vs non-causal.""" + config_base = { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + } + + method_causal = FlashSkipSoftmax({**config_base, "is_causal": True}) + method_noncausal = FlashSkipSoftmax({**config_base, "is_causal": False}) + + attn = torch.randn(1, 1, 256, 256) # 2x2 blocks + + _, stats_causal = method_causal.calc_correction_factor_and_p(attn, "prefill") + _, stats_noncausal = method_noncausal.calc_correction_factor_and_p(attn, "prefill") + + # Causal: 2*(2+1)/2 = 3 blocks + # Non-causal: 2*2 = 4 blocks + assert stats_causal["total_blocks"] == 3 + assert stats_noncausal["total_blocks"] == 4 + + def test_apply_sparsity_assertions(self): + """Test apply_sparsity input validation.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Test: attention_scores required + with pytest.raises(AssertionError, match="attention_scores must be provided"): + method.apply_sparsity() + + # Test: 4D shape required + with pytest.raises(AssertionError, match="Expected 4D"): + method.apply_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + + def test_name_property(self): + """Test method name property.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + assert method.name == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py new file mode 100644 index 000000000..4558ca22b --- /dev/null +++ 
b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -0,0 +1,623 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sparse attention calibration.""" + +import pytest + +pytest.importorskip("transformers") + +import numpy as np +from _test_utils.torch_sparsity.sparse_attention_common import ( + SimpleAttentionModel, + SimpleTransformerEncoder, +) +from pydantic import ValidationError + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.calibration import ( + DynamicThresholdCalibrator, + RulerDatasetBuilder, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.calibrate import ( + _extract_calibration_config, + calibrate_sparse_attention, + create_calibration_forward_loop, +) +from modelopt.torch.sparsity.attention_sparsity.calibration.dataset import _generate_target_lengths +from modelopt.torch.sparsity.attention_sparsity.config import CalibrationConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestLengthGeneration: + """Test automatic target length generation.""" + + def test_generate_target_lengths_default(self): + """Test default 4 bins generation.""" + lengths = _generate_target_lengths(32768, num_length_bins=4) + assert lengths == [32768, 16384, 8192, 4096] + + def test_generate_target_lengths_stops_at_minimum(self): + """Test generation stops at minimum threshold.""" + lengths = _generate_target_lengths(2048, num_length_bins=4) + assert lengths == [2048, 1024] # Stops at 1024 + + def test_generate_target_lengths_fewer_bins(self): + """Test with fewer bins.""" + lengths = _generate_target_lengths(16384, num_length_bins=2) + assert lengths == [16384, 8192] + + def test_generate_target_lengths_more_bins(self): + """Test with more bins.""" + lengths = _generate_target_lengths(65536, num_length_bins=6) + assert lengths == [65536, 32768, 16384, 8192, 4096, 2048] + + def test_generate_target_lengths_exactly_minimum(self): + """Test when max_seqlen equals minimum.""" + lengths = _generate_target_lengths(1024, num_length_bins=4) + assert lengths == [1024] + + +class TestRulerDatasetBuilder: + """Test RULER dataset generation without requiring real tokenizers.""" + + def test_builder_initialization(self): + """Test that builder initializes correctly.""" + builder = RulerDatasetBuilder( + samples=12, + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + assert builder.total_samples == 12 + assert builder.max_seqlen == 2048 + assert builder.target_lengths == [2048, 1024] + assert builder.samples_per_length == [6, 6] # Evenly distributed + assert len(builder.subtasks) == 6 # All RULER_TASKS + assert builder.seed == 42 + + def test_builder_initialization_invalid_config(self): + """Test that builder raises error for invalid inputs.""" 
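+        # Both invalid configurations below should fail at construction time,
+        # before build_calibration_dataset() is ever called.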
+ # Test invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + RulerDatasetBuilder( + samples=0, + max_seqlen=2048, + tokenizer_name_or_path="gpt2", + ) + + # Test max_seqlen below minimum + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + RulerDatasetBuilder( + samples=4, + max_seqlen=512, # Below minimum + tokenizer_name_or_path="gpt2", + ) + + def test_dataset_generation_minimal(self): + """Test generating small dataset.""" + builder = RulerDatasetBuilder( + samples=12, # 6 tasks x 2 lengths = need 12 for 1 per task per length + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should generate 12 samples (6 tasks x 1 sample per task x 2 lengths) + assert len(dataset) == 12 + assert all(isinstance(sample, dict) for sample in dataset) + + def test_dataset_structure(self): + """Test that dataset has correct structure.""" + builder = RulerDatasetBuilder( + samples=6, # Need at least 6 (1 per task) + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + sample = dataset[0] + + # Check required fields + assert "input" in sample + assert "length" in sample + assert "task" in sample + assert "target_length" in sample + + # Check field types + assert isinstance(sample["input"], str) + assert isinstance(sample["length"], int) + assert isinstance(sample["task"], str) + assert sample["length"] > 0 + + def test_sample_distribution(self): + """Test that samples are distributed across lengths and subtasks.""" + builder = RulerDatasetBuilder( + samples=24, # 6 tasks x 2 lengths x 2 samples = 24 + max_seqlen=2048, # Generates: [2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Should have 24 samples (12 per length, 2 per task) + assert len(dataset) == 24 + + # Check task distribution (should have variety from all RULER_TASKS) + tasks = [s["task"] for s in dataset] + # Verify we have all 6 tasks represented + assert len(set(tasks)) == 6 + + def test_length_targeting(self): + """Test that generated lengths are close to targets.""" + builder = RulerDatasetBuilder( + samples=6, # 1 per task + max_seqlen=1024, # Generates: [1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + dataset = builder.build_calibration_dataset() + + # Lengths should be within reasonable range of target + # RULER aims for 70-90% of target length for context + for sample in dataset: + assert 700 < sample["length"] < 1400 # Reasonable range around 1024 + + def test_uneven_sample_distribution(self): + """Test that samples are distributed evenly (remainder dropped).""" + builder = RulerDatasetBuilder( + samples=50, # 50 samples across 4 lengths + max_seqlen=8192, # Generates: [8192, 4096, 2048, 1024] + tokenizer_name_or_path="gpt2", + seed=42, + ) + + # Even distribution: 50//4 = 12 per length + assert builder.total_samples == 50 + assert builder.target_lengths == [8192, 4096, 2048, 1024] + assert builder.samples_per_length == [12, 12, 12, 12] + assert sum(builder.samples_per_length) == 48 # 2 samples dropped (remainder) + + # Actual generated samples: 12//6=2 per task, 4 lengths, 6 tasks + # Total: 2 x 6 x 4 = 48 + dataset = builder.build_calibration_dataset() + assert len(dataset) == 48 + + +class TestDynamicThresholdCalibrator: + """Test calibration algorithm correctness.""" + + def test_calibrator_initialization(self): + """Test that 
calibrator initializes correctly.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[1e-4, 1e-3, 1e-2], + ) + + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + def test_calibrator_default_threshold_trials(self): + """Test that calibrator has default threshold trials.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + ) + + # Should have default threshold trials + assert calibrator.threshold_trials is not None + assert len(calibrator.threshold_trials) == 12 + # Check they are positive and in valid range + trials = calibrator.threshold_trials + assert all(0 < t < 1 for t in trials) + + def test_regression_calculation_synthetic(self): + """Test 'a' parameter calculation with synthetic data.""" + # Create synthetic optimal pairs + # If threshold = a / length, then with perfect data: + # length=1000, threshold=10 => a=10000 + # length=2000, threshold=5 => a=10000 + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0, "achieved_sparsity": 0.5}, + {"length": 2000, "optimal_threshold": 5.0, "achieved_sparsity": 0.5}, + {"length": 4000, "optimal_threshold": 2.5, "achieved_sparsity": 0.5}, + ] + + # Manual regression calculation + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + # Calculate 'a' using least squares + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should be close to 10000 + assert 9500 < a_parameter < 10500 + + # Test individual 'a' values + a_per_sample = y * lengths + assert np.allclose(a_per_sample, 10000, rtol=0.05) + + def test_multiple_samples_different_lengths(self): + """Test regression with varied lengths.""" + # More realistic scenario with some variance + optimal_pairs = [ + {"length": 500, "optimal_threshold": 20.0, "achieved_sparsity": 0.5}, + {"length": 1000, "optimal_threshold": 10.5, "achieved_sparsity": 0.51}, + {"length": 2000, "optimal_threshold": 5.2, "achieved_sparsity": 0.49}, + {"length": 4000, "optimal_threshold": 2.4, "achieved_sparsity": 0.50}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Should still be around 10000 with some tolerance for variance + assert 9000 < a_parameter < 11000 + + def test_r_squared_calculation(self): + """Test R-squared calculation for regression quality.""" + # Perfect fit data + optimal_pairs = [ + {"length": 1000, "optimal_threshold": 10.0}, + {"length": 2000, "optimal_threshold": 5.0}, + {"length": 4000, "optimal_threshold": 2.5}, + ] + + lengths = np.array([p["length"] for p in optimal_pairs]) + thresholds = np.array([p["optimal_threshold"] for p in optimal_pairs]) + + x = 1.0 / lengths + y = thresholds + + a_parameter = np.sum(x * y) / np.sum(x**2) + + # Calculate R-squared + y_pred = a_parameter * x + ss_res = np.sum((y - y_pred) ** 2) + ss_tot = np.sum((y - np.mean(y)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # Perfect fit should have R^2 close to 1 + assert r_squared > 0.99 + + +class TestCalibrationIntegration: + """Test end-to-end calibration without GPU.""" + + def test_calibration_disabled(self): + """Test that no calibration occurs when disabled.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + 
"method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # No forward_loop needed when calibration disabled + sparse_model = sparsify(model, config) + + # Check that sparse attention is applied but not calibrated + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + # Check that no calibration is set + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + method = module._sparse_method_instance + assert not getattr(method, "threshold_scale_factor", None) + + def test_sparsify_with_calibration_requires_forward_loop(self): + """Test that calibration requires forward_loop or proper model config.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 4, + "max_seqlen": 1024, + }, + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 64, + "bc": 64, + "enable": True, + }, + }, + } + + # Without forward_loop and without model.config._name_or_path, should raise ValueError + with pytest.raises(ValueError, match="Could not load tokenizer"): + sparsify(model, config, forward_loop=None) + + def test_multiple_sparse_modules(self): + """Test that calibration handles multiple attention layers.""" + model = SimpleTransformerEncoder() + + config = { + "sparse_cfg": {"*attn*": {"threshold": 1e-3, "br": 64, "bc": 64, "enable": True}}, + } + + sparse_model = sparsify(model, config) + + # Count sparse attention modules + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + + # Should have 2 sparse attention modules + assert sparse_count == 2 + + def test_calibration_config_validation(self): + """Test CalibrationConfig validation.""" + # Valid config + config = CalibrationConfig( + target_sparse_ratio=0.5, + samples=48, + max_seqlen=32768, + ) + assert config.target_sparse_ratio == 0.5 + assert config.samples == 48 + assert config.max_seqlen == 32768 + + # Invalid target_sparse_ratio (> 1.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=1.5, samples=48, max_seqlen=32768) + + # Invalid target_sparse_ratio (< 0.0) + with pytest.raises(ValueError, match="target_sparse_ratio must be between"): + CalibrationConfig(target_sparse_ratio=-0.1, samples=48, max_seqlen=32768) + + # Invalid samples + with pytest.raises(ValueError, match="samples must be positive"): + CalibrationConfig(target_sparse_ratio=0.5, samples=0, max_seqlen=32768) + + # Invalid max_seqlen + with pytest.raises(ValueError, match="max_seqlen must be >= 1024"): + CalibrationConfig(target_sparse_ratio=0.5, samples=48, max_seqlen=512) + + def test_threshold_trials_validation(self): + """Test threshold_trials validation.""" + # Valid custom threshold_trials + config = CalibrationConfig( + target_sparse_ratio=0.5, + threshold_trials=[1e-5, 1e-4, 1e-3, 1e-2], + ) + assert config.threshold_trials == [1e-5, 1e-4, 1e-3, 1e-2] + + # None (use defaults) + config_default = CalibrationConfig(target_sparse_ratio=0.5) + assert config_default.threshold_trials is None + + # Invalid: empty list + with pytest.raises(ValueError, match="threshold_trials must not be empty"): + CalibrationConfig(threshold_trials=[]) + + # Invalid: threshold out of range (>= 1.0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 1.0]) + + # 
Invalid: threshold out of range (<= 0) + with pytest.raises(ValueError, match="must be in range"): + CalibrationConfig(threshold_trials=[1e-4, 0]) + + # Invalid: not a list (Pydantic raises ValidationError, not ValueError) + with pytest.raises(ValidationError, match="Input should be a valid list"): + CalibrationConfig(threshold_trials=1e-4) + + +class TestDynamicThresholdCalibratorMethods: + """Test individual methods of DynamicThresholdCalibrator.""" + + def test_set_threshold(self): + """Test _set_threshold method.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + # Get sparse modules + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + assert len(modules) > 0 + + # Create calibrator and set threshold + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + calibrator._set_threshold(modules, 0.05) + + # Verify threshold was set + for module in modules: + assert module._sparse_method_instance.threshold == 0.05 + + def test_enable_disable_calibration_mode(self): + """Test _enable_calibration_mode and _disable_calibration_mode.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Enable calibration mode + calibrator._enable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager is not None + assert module._stats_manager.enabled is True + assert module._stats_manager.calibration_mode is True + assert module._sparse_method_instance._calibration_mode is True + + # Disable calibration mode + calibrator._disable_calibration_mode(modules) + + for module in modules: + assert module._stats_manager.calibration_mode is False + assert module._sparse_method_instance._calibration_mode is False + + def test_extract_calibration_stats_no_stats(self): + """Test _extract_calibration_stats when no stats collected.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + sparse_model = sparsify(model, config) + + modules = [m for m in sparse_model.modules() if isinstance(m, SparseAttentionModule)] + + calibrator = DynamicThresholdCalibrator(target_sparse_ratio=0.5) + + # Extract stats without running any forward passes + stats = calibrator._extract_calibration_stats(modules) + + # Should return empty list + assert stats == [] + + def test_calibrator_with_single_sample(self): + """Test calibrator edge case with only one sample.""" + calibrator = DynamicThresholdCalibrator( + target_sparse_ratio=0.5, + threshold_trials=[0.001, 0.01, 0.1], + ) + + # Even with one sample, regression should work + assert calibrator.target_sparse_ratio == 0.5 + assert len(calibrator.threshold_trials) == 3 + + +class TestCalibrateFunction: + """Test calibrate_sparse_attention function.""" + + def test_calibrate_no_config(self): + """Test calibration when config has no calibration section.""" + model = 
SimpleAttentionModel(hidden_size=64, num_heads=4) + + # Config without calibration + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + # Should return empty dict when no calibration config + result = calibrate_sparse_attention(model, config) + + assert result == {} + + def test_extract_calibration_config(self): + """Test _extract_calibration_config function.""" + # Config with calibration + config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": 0.3, + "samples": 12, + "max_seqlen": 2048, + }, + "*attn*": { + "method": "flash_skip_softmax", + }, + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is not None + assert calib_config.target_sparse_ratio == 0.3 + assert calib_config.samples == 12 + assert calib_config.max_seqlen == 2048 + + def test_extract_calibration_config_none(self): + """Test _extract_calibration_config when no calibration.""" + # Config without calibration + config = { + "sparse_cfg": { + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 0.1, + } + }, + } + + calib_config = _extract_calibration_config(config) + + assert calib_config is None + + def test_create_calibration_forward_loop(self): + """Test create_calibration_forward_loop function.""" + calibration_data = [ + {"input": "This is a test sample.", "length": 512}, + {"input": "Another test sample.", "length": 1024}, + ] + + forward_loop = create_calibration_forward_loop( + calibration_data=calibration_data, + tokenizer_name_or_path="gpt2", + ) + + # Should return a callable + assert callable(forward_loop) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSkipSoftmaxConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert "sparse_cfg" in config.model_dump() + # Check default pattern has method + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert 
"sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSkipSoftmaxConfig: + """Test FlashSkipSoftmaxConfig.""" + + def test_default_values(self): + """Test default values for flash_skip_softmax config.""" + config = FlashSkipSoftmaxConfig() + assert "*attention*" in config.sparse_cfg + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py new file mode 100644 index 000000000..1ba86c143 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sparse attention conversion and replacement.""" + +import pytest + +pytest.importorskip("transformers") + +import torch.nn as nn +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoderLayer, +) + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, + print_sparse_attention_summary, +) +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestSparseAttentionReplacement: + """Test module replacement logic.""" + + def test_basic_replacement(self): + """Test that attention modules are replaced with sparse versions.""" + model = SimpleAttentionModel() + + # Count original attention modules + original_attention_count = sum( + isinstance(m, nn.MultiheadAttention) for m in model.modules() + ) + assert original_attention_count > 0 + + # Apply sparse attention + sparse_model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Count sparse attention modules + sparse_attention_count = sum( + isinstance(m, SparseAttentionModule) for m in sparse_model.modules() + ) + + # Verify replacement occurred + assert sparse_attention_count > 0 + + def test_enable_disable_toggle(self): + """Test enabling and disabling sparse attention.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Check initially enabled + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + # Disable all sparse attention modules + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Re-enable all sparse attention modules + enable_sparse_attention(model, "*") + for module in model.modules(): 
+ if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_pattern_based_replacement(self): + """Test pattern-based selective replacement.""" + model = SimpleTransformerEncoderLayer() + + # Apply with pattern + config = { + "sparse_cfg": { + "*self_attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + }, + "default": {"enable": False}, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse modules exist + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + +class TestConversionEdgeCases: + """Test edge cases and error paths in conversion.""" + + def test_callable_filter(self): + """Test using callable filter instead of wildcard.""" + model = SimpleAttentionModel() + + # Use callable filter + def filter_func(name): + return "attn" in name + + config = { + "sparse_cfg": { + filter_func: { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + }, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + def test_no_matching_modules(self): + """Test pattern that matches nothing.""" + model = SimpleAttentionModel() + + config = { + "sparse_cfg": { + "*nonexistent*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + }, + }, + } + + # Should not error, even with no matches + sparse_attn.sparsify(model, config) + + def test_disable_enable_functions(self): + """Test disable/enable utility functions.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, + ) + + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Disable all + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Enable all + enable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Total sparse attention modules:" in captured.out + assert "Enabled:" in captured.out + + def test_restore_sparse_attention_model(self): + """Test save/restore via modelopt_state.""" + # Create and sparsify original model + model_orig = SimpleAttentionModel() + model_orig = sparse_attn.sparsify(model_orig, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Save state + state_dict = mto.modelopt_state(model_orig) + + # Restore to new model + model_restored = SimpleAttentionModel() + mto.restore_from_modelopt_state(model_restored, state_dict) + + # Verify restoration + has_sparse = any(isinstance(m, SparseAttentionModule) for m in model_restored.modules()) + assert has_sparse + + # Verify module is configured + for module in model_restored.modules(): + if isinstance(module, SparseAttentionModule): + assert hasattr(module, "_method") + assert module._method == "flash_skip_softmax" + + +class TestSparseAttentionModuleMethods: + """Test SparseAttentionModule methods.""" + + def 
test_get_stats_with_stats_manager(self): + """Test get_stats() when stats manager exists and is enabled.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": True, # Enable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + sparse_module = module + break + + assert sparse_module is not None + assert sparse_module._stats_manager is not None + + # Get stats (should return summary) + stats = sparse_module.get_stats() + + assert isinstance(stats, dict) + assert "module" in stats + assert "total_calls" in stats + assert "average_sparsity" in stats + + def test_get_stats_without_stats_manager(self): + """Test get_stats() when stats manager is None.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "collect_stats": False, # Disable stats collection + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Stats manager should be None + assert module._stats_manager is None + + # get_stats should return empty dict + stats = module.get_stats() + assert stats == {} + break + + def test_get_threshold_info(self): + """Test get_threshold_info() method.""" + model = SimpleAttentionModel() + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Find sparse module and test threshold info + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + info = module.get_threshold_info() + + assert isinstance(info, dict) + assert "type" in info + assert info["type"] == "static" + assert info["value"] == 0.005 + break diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py new file mode 100644 index 000000000..e7e32e153 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for sparse attention mode registry.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.opt.mode import _ModeRegistryCls +from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry + + +def test_sparse_attention_mode_exists(): + """Test that sparse_attention mode is registered.""" + assert "sparse_attention" in SparseAttentionModeRegistry + + +def test_sparse_attention_mode_descriptor(): + """Test sparse attention mode descriptor properties.""" + mode_descriptor = _ModeRegistryCls.get_from_any("sparse_attention") + + assert mode_descriptor is not None + assert hasattr(mode_descriptor, "config_class") + assert hasattr(mode_descriptor, "convert") + + +def test_mode_registry_get(): + """Test getting mode from registry.""" + mode = SparseAttentionModeRegistry["sparse_attention"] + assert mode is not None diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py new file mode 100644 index 000000000..02188e97a --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py @@ -0,0 +1,334 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for SparseAttentionStatsManager.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.stats_manager import SparseAttentionStatsManager + + +class TestStatsManagerInitialization: + """Test stats manager initialization.""" + + def test_initialization_defaults(self): + """Test default initialization.""" + manager = SparseAttentionStatsManager(module_name="test_module") + + assert manager.module_name == "test_module" + assert manager.enabled is True + assert manager.calibration_mode is False + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + + def test_initialization_disabled(self): + """Test initialization with disabled stats.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=False) + + assert manager.enabled is False + assert manager.calibration_mode is False + + def test_initialization_custom_name(self): + """Test initialization with custom module name.""" + manager = SparseAttentionStatsManager(module_name="custom.attention.module") + + assert manager.module_name == "custom.attention.module" + + +class TestStatsCollection: + """Test statistics collection functionality.""" + + def test_collect_stats_enabled(self): + """Test collecting stats when enabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 1 + assert manager.aggregated_stats["total_blocks"] == 100 + assert manager.aggregated_stats["sparse_blocks"] == 50 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 1 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 + + def test_collect_stats_disabled(self): + """Test that collect() is no-op when disabled.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=False) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + + manager.collect(stats) + + # Should remain at initial values + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + + def test_collect_multiple_calls(self): + """Test accumulation over multiple collect calls.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect multiple times + for i in range(5): + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + } + manager.collect(stats) + + assert manager.aggregated_stats["total_calls"] == 5 + assert manager.aggregated_stats["total_blocks"] == 500 + assert manager.aggregated_stats["sparse_blocks"] == 250 + assert manager.aggregated_stats["phase_counts"]["prefill"] == 5 + + def test_collect_different_phases(self): + """Test phase counting.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect prefill stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + # Collect decode stats + manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5}) + + assert 
manager.aggregated_stats["phase_counts"]["prefill"] == 2 + assert manager.aggregated_stats["phase_counts"]["decode"] == 1 + assert manager.aggregated_stats["phase_counts"]["unknown"] == 0 + + +class TestCalibrationMode: + """Test calibration mode functionality.""" + + def test_calibration_mode_per_sample_collection(self): + """Test that calibration mode stores per-sample stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Enable calibration mode + manager.set_calibration_mode(enabled=True) + + stats = { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + + manager.collect(stats) + + # Should store in per_sample_stats + assert len(manager.per_sample_stats) == 1 + assert manager.per_sample_stats[0]["module"] == "test" + assert manager.per_sample_stats[0]["sparsity"] == 0.5 + assert manager.per_sample_stats[0]["sample_length"] == 1024 + assert manager.per_sample_stats[0]["phase"] == "prefill" + + def test_calibration_mode_off(self): + """Test that per-sample stats are not collected when calibration mode is off.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + # Calibration mode is off by default + + stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50} + + manager.collect(stats) + + # Should NOT store in per_sample_stats + assert len(manager.per_sample_stats) == 0 + + # But should still aggregate + assert manager.aggregated_stats["total_calls"] == 1 + + def test_set_calibration_mode_with_reset(self): + """Test set_calibration_mode with reset_history=True.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats in calibration mode + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Re-enable with reset + manager.set_calibration_mode(enabled=True, reset_history=True) + assert len(manager.per_sample_stats) == 0 # Should be cleared + + def test_set_calibration_mode_without_reset(self): + """Test set_calibration_mode with reset_history=False.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect some stats + manager.set_calibration_mode(enabled=True) + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + assert len(manager.per_sample_stats) == 1 + + # Disable without reset + manager.set_calibration_mode(enabled=False, reset_history=False) + assert len(manager.per_sample_stats) == 1 # Should be preserved + + +class TestGetSummary: + """Test get_summary() functionality.""" + + def test_get_summary_with_data(self): + """Test get_summary returns correct averages.""" + manager = SparseAttentionStatsManager(module_name="test_module", enabled=True) + + # Collect stats + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + + summary = manager.get_summary() + + assert summary["module"] == "test_module" + assert summary["total_calls"] == 2 + # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4 + assert summary["average_sparsity"] == 0.4 + assert summary["phase_distribution"]["prefill"] == 2 + + def test_get_summary_no_data(self): + """Test get_summary with no collected data.""" + 
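+        # With no collect() calls, the summary should fall back to zeroed counters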
manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + summary = manager.get_summary() + + assert summary["module"] == "test" + assert summary["total_calls"] == 0 + assert summary["average_sparsity"] == 0.0 + assert summary["phase_distribution"]["prefill"] == 0 + + def test_get_summary_zero_blocks(self): + """Test get_summary when total_blocks is zero.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + # Collect stats with zero blocks + manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0}) + + summary = manager.get_summary() + + assert summary["average_sparsity"] == 0.0 # Should handle division by zero + + +class TestGetCalibrationStats: + """Test get_calibration_stats() functionality.""" + + def test_get_calibration_stats(self): + """Test retrieving per-sample calibration stats.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect multiple samples + for i in range(3): + manager.collect( + { + "sparsity": 0.3 + i * 0.1, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 30, + "sample_length": 1024 + i * 512, + } + ) + + calib_stats = manager.get_calibration_stats() + + assert len(calib_stats) == 3 + assert calib_stats[0]["sparsity"] == 0.3 + assert calib_stats[1]["sparsity"] == 0.4 + assert calib_stats[2]["sparsity"] == 0.5 + + def test_get_calibration_stats_empty(self): + """Test get_calibration_stats when no calibration data.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + + calib_stats = manager.get_calibration_stats() + + assert calib_stats == [] + + +class TestReset: + """Test reset functionality.""" + + def test_reset(self): + """Test reset() clears all statistics.""" + manager = SparseAttentionStatsManager(module_name="test", enabled=True) + manager.set_calibration_mode(enabled=True) + + # Collect some stats + manager.collect( + { + "sparsity": 0.5, + "phase": "prefill", + "total_blocks": 100, + "sparse_blocks": 50, + "sample_length": 1024, + } + ) + manager.collect( + { + "sparsity": 0.3, + "phase": "decode", + "total_blocks": 10, + "sparse_blocks": 3, + "sample_length": 128, + } + ) + + # Verify stats exist + assert manager.aggregated_stats["total_calls"] == 2 + assert len(manager.per_sample_stats) == 2 + + # Reset + manager.reset() + + # All stats should be cleared + assert manager.aggregated_stats["total_calls"] == 0 + assert manager.aggregated_stats["total_blocks"] == 0 + assert manager.aggregated_stats["sparse_blocks"] == 0 + assert manager.per_sample_stats == [] + assert manager.aggregated_stats["phase_counts"]["prefill"] == 0 + assert manager.aggregated_stats["phase_counts"]["decode"] == 0 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py new file mode 100644 index 000000000..ac9f46a54 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for threshold calibration functionality.""" + +import pytest + +pytest.importorskip("transformers") + +from _test_utils.torch_sparsity.sparse_attention_common import SimpleAttentionModel + +from modelopt.torch.sparsity.attention_sparsity import sparsify +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestFlashSkipSoftmaxThresholdInfo: + """Test FlashSkipSoftmax.get_threshold_info() method.""" + + def test_static_threshold(self): + """Test threshold info for static threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.001 + + def test_phased_threshold(self): + """Test threshold info for phase-specific thresholds.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": {"prefill": 0.001, "decode": 0.0001}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + assert info["type"] == "static_phased" + assert "thresholds" in info + assert info["thresholds"]["prefill"] == 0.001 + assert info["thresholds"]["decode"] == 0.0001 + assert "current" in info + + def test_dynamic_calibrated_threshold(self): + """Test threshold info for calibrated dynamic threshold.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Simulate calibration setting scale factor + method.threshold_scale_factor = 437.5 + + info = method.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 437.5 + assert info["formula"] == "λ / length" + assert "example_lengths" in info + assert abs(info["example_lengths"][1024] - 437.5 / 1024) < 1e-6 + assert abs(info["example_lengths"][2048] - 437.5 / 2048) < 1e-6 + + def test_threshold_info_structure(self): + """Test that threshold info has expected structure.""" + method = FlashSkipSoftmax( + method_config={ + "threshold": 0.001, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + info = method.get_threshold_info() + + # Should always have 'type' key + assert "type" in info + assert isinstance(info, dict) + + +class TestSparseAttentionModuleThresholdInfo: + """Test SparseAttentionModule.get_threshold_info() delegation.""" + + def test_module_delegates_to_method(self): + """Test that module correctly delegates to sparse method instance.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.005, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find sparse attention module + sparse_module = None + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + 
sparse_module = module + break + + assert sparse_module is not None + + # Test get_threshold_info + info = sparse_module.get_threshold_info() + + assert info["type"] == "static" + assert info["value"] == 0.005 + + def test_module_with_calibrated_threshold(self): + """Test module reports calibrated threshold correctly.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module and set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 500.0 + break + + # Get threshold info + info = module.get_threshold_info() + + assert info["type"] == "dynamic" + assert info["scale_factor"] == 500.0 + + def test_module_without_method_instance(self): + """Test get_threshold_info when sparse method instance doesn't exist.""" + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Find module + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + # Remove sparse method instance to test fallback + delattr(module, "_sparse_method_instance") + + info = module.get_threshold_info() + + assert info["type"] == "none" + assert info["value"] is None + break + + +class TestPrintSparseAttentionSummaryIntegration: + """Test integration with print_sparse_attention_summary.""" + + def test_summary_displays_static_threshold(self, capsys): + """Test that print function displays static thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Static (1.00e-03)" in captured.out + assert "flash_skip_softmax" in captured.out + + def test_summary_displays_dynamic_threshold(self, capsys): + """Test that print function displays dynamic thresholds.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + print_sparse_attention_summary, + ) + + model = SimpleAttentionModel(hidden_size=64, num_heads=4) + + config = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 0.001, + "br": 64, + "bc": 64, + "enable": True, + } + }, + } + + sparse_model = sparsify(model, config) + + # Set calibrated threshold + for module in sparse_model.modules(): + if isinstance(module, SparseAttentionModule): + module._sparse_method_instance.threshold_scale_factor = 437.5 + + print_sparse_attention_summary(sparse_model) + + captured = capsys.readouterr() + assert "Dynamic (λ=437.500000)" in captured.out + assert "flash_skip_softmax" in captured.out
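
# Note (illustration only, not part of the patch): the dynamic-threshold assertions above
# reduce to a single per-length computation. A minimal Python sketch of that arithmetic,
# assuming the "λ / length" formula the tests assert and the same λ = 437.5 the tests set:
#
#     scale_factor = 437.5  # λ, as assigned to threshold_scale_factor in the tests
#     for length in (1024, 2048):
#         # Effective threshold shrinks with sequence length: λ / length
#         print(f"length={length}: effective threshold ≈ {scale_factor / length:.6f}")
#
# This prints ≈ 0.427246 for length 1024 and ≈ 0.213623 for length 2048, matching the
# example_lengths entries checked in test_dynamic_calibrated_threshold.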