From 376b2068f6dd69d5fe5ce50af1e3043bf3311050 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 Sep 2025 17:59:31 -0700 Subject: [PATCH 1/2] add initial support for sparse attention Signed-off-by: Kai Xu --- examples/llm_sparse_attention/hf_spar_attn.py | 368 +++++++++++++++++ .../sparsity/attention_sparsity/__init__.py | 24 ++ .../calibration/__init__.py | 26 ++ .../sparsity/attention_sparsity/config.py | 358 ++++++++++++++++ .../sparsity/attention_sparsity/conversion.py | 387 ++++++++++++++++++ .../attention_sparsity/methods/__init__.py | 27 ++ .../methods/flash_softmax_skip.py | 289 +++++++++++++ .../attention_sparsity/methods/registry.py | 120 ++++++ .../torch/sparsity/attention_sparsity/mode.py | 85 ++++ .../attention_sparsity/model_sparsify.py | 197 +++++++++ .../attention_sparsity/nn/__init__.py | 20 + .../attention_sparsity/nn/sparse_attention.py | 205 ++++++++++ .../attention_sparsity/plugins/__init__.py | 22 + .../attention_sparsity/plugins/huggingface.py | 122 ++++++ 14 files changed, 2250 insertions(+) create mode 100644 examples/llm_sparse_attention/hf_spar_attn.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/config.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/conversion.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/methods/registry.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/mode.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/model_sparsify.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/nn/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py diff --git a/examples/llm_sparse_attention/hf_spar_attn.py b/examples/llm_sparse_attention/hf_spar_attn.py new file mode 100644 index 000000000..461af581e --- /dev/null +++ b/examples/llm_sparse_attention/hf_spar_attn.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Example script for applying sparse attention to HuggingFace models.""" + +import argparse +import random +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_CALIB, + SKIP_SOFTMAX_DEFAULT, +) +from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +# You can define custom configurations or use the default +SPARSE_ATTN_CFG_CHOICES = { + "skip_softmax": SKIP_SOFTMAX_DEFAULT, + "skip_softmax_calib": SKIP_SOFTMAX_CALIB, +} + + +def print_sparsity_stats(model: nn.Module): + """Print sparsity statistics if available.""" + module_stats = [] + for name, module in model.named_modules(): + if hasattr(module, "get_stats"): + stats = module.get_stats() + if stats and "average_sparsity" in stats: + module_stats.append((name, stats["average_sparsity"])) + + if not module_stats: + print("No sparsity statistics available") + return + + # Check if all modules have the same sparsity + sparsities = [s for _, s in module_stats] + if len(set(sparsities)) == 1: + # All identical - show summary + print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}") + else: + # Different sparsities - show individual values + avg_sparsity = sum(sparsities) / len(sparsities) + print(f"Average sparsity: {avg_sparsity:.2%}") + print("Per-module breakdown:") + for name, sparsity in module_stats: + print(f" {name}: {sparsity:.2%} sparse") + + +def get_narrativeqa_samples(num_samples=3): + """Load samples from NarrativeQA dataset for testing. + + Args: + num_samples: Number of samples to generate + """ + # Load NarrativeQA dataset + dataset = load_dataset("narrativeqa", split="test", streaming=True) + + samples = [] + for i, item in enumerate(dataset): + if i >= num_samples: + break + + # Combine document context and question + context = item.get("document", {}).get("text", "") + question = item.get("question", {}).get("text", "") + + if context and question: + # Use the full context as-is + prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:" + samples.append(prompt) + + if not samples: + raise ValueError("Could not load NarrativeQA samples") + + print(f"Loaded {len(samples)} NarrativeQA samples") + return samples + + +def truncate_text(text: str, tokenizer, max_length: int): + """Truncate text from the middle to preserve beginning and end. 
+ + Args: + text: Input text to truncate + tokenizer: Tokenizer to use for encoding + max_length: Maximum number of tokens + + Returns: + Truncated text that fits within max_length tokens + """ + # First tokenize to see if truncation is needed + tokens = tokenizer.encode(text, add_special_tokens=True) + + if len(tokens) <= max_length: + return text + + # Need to truncate - preserve beginning and end + # Reserve some tokens for special tokens + available_tokens = max_length - 2 # Account for special tokens + + # Split tokens roughly in half for beginning and end + begin_tokens = available_tokens // 2 + end_tokens = available_tokens - begin_tokens + + # Decode beginning and end parts + begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True) + end_text = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True) + + # Combine with ellipsis marker + return begin_text + " [...] " + end_text + + +def verify_outputs(model, tokenizer, args): + """Compare outputs between baseline and sparse attention models.""" + # Update seq_len to match calibration max_seqlen if calibration was used + base_config = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {}) + if "calibration" in base_config and "max_seqlen" in base_config["calibration"]: + calib_max_seqlen = base_config["calibration"]["max_seqlen"] + if args.seq_len != calib_max_seqlen: + print( + f"\nNote: Updating test seq_len from {args.seq_len} to {calib_max_seqlen} " + f"to match calibration config" + ) + args.seq_len = calib_max_seqlen + + # Load and prepare a single test prompt + print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)") + prompts = get_narrativeqa_samples(num_samples=1) + prompt = prompts[0] + + # Prepare inputs + truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len) + display_prompt = ( + truncated_prompt[:150] + "..." 
if len(truncated_prompt) > 150 else truncated_prompt + ) + + inputs = tokenizer( + truncated_prompt, + return_tensors="pt", + max_length=args.seq_len, + truncation=True, + padding=False, + ) + if torch.cuda.is_available(): + inputs = {k: v.cuda() for k, v in inputs.items()} + + print("\n" + "=" * 60) + print("BASELINE vs SPARSE ATTENTION COMPARISON") + print("=" * 60) + print(f"\nTest prompt: {display_prompt}") + print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})") + if "[...]" in truncated_prompt: + print("Note: Text was middle-truncated to fit token limit") + + # Helper function to generate text + def generate_text(model, inputs, args, tokenizer): + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature if args.do_sample else 1.0, + pad_token_id=tokenizer.pad_token_id, + ) + input_length = inputs["input_ids"].shape[1] + generated_ids = outputs[0][input_length:] + return tokenizer.decode(generated_ids, skip_special_tokens=True) + + # Find all sparse attention modules + sparse_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)] + + # Generate baseline by temporarily disabling sparse attention + print("\n" + "-" * 60) + print("Generating baseline (sparse attention disabled)...") + for module in sparse_modules: + module.disable() + baseline_text = generate_text(model, inputs, args, tokenizer) + + # Generate with sparse attention enabled + print("\nGenerating with sparse attention (calibrated thresholds)...") + for module in sparse_modules: + module.enable() + sparse_text = generate_text(model, inputs, args, tokenizer) + + # Display comparison + print("\n" + "-" * 60) + print("RESULTS:") + baseline_display = baseline_text[:300] + "..." if len(baseline_text) > 300 else baseline_text + sparse_display = sparse_text[:300] + "..." 
if len(sparse_text) > 300 else sparse_text + + print(f"\nBaseline: {baseline_display}") + print(f"With Sparse: {sparse_display}") + + if baseline_text == sparse_text: + print("\nOutputs are identical") + else: + print("\nOutputs differ") + + +def sparsify_model(model, args): + """Apply sparse attention to the model with optional calibration.""" + print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}") + base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn] + + # Create modified config with selected backend + modified_sparse_cfg = {} + for pattern, cfg in base_config["sparse_cfg"].items(): + modified_cfg = cfg.copy() + modified_cfg["backend"] = args.backend + modified_sparse_cfg[pattern] = modified_cfg + + # Create new config with modified settings + sparse_config = SparseAttentionConfig( + method=base_config["method"], + sparse_cfg=modified_sparse_cfg, + collect_stats=True, # Enable stats collection for monitoring + ) + + # Sparsify with optional calibration - framework handles calibration automatically + model = mtsa.sparsify(model, config=sparse_config) + + print("Sparse attention applied successfully!") + + # Show sparsity statistics + print("\n" + "=" * 60) + print("Sparsity Statistics") + print("=" * 60) + print_sparsity_stats(model) + + return model + + +def main(args): + """Main function to run the selected mode.""" + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + launch_memory_monitor() + + print(f"Loading model: {args.pyt_ckpt_path}") + + # Load model and tokenizer + # Note: attn_implementation="eager" is required for calibration to work properly + # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) + model = AutoModelForCausalLM.from_pretrained( + args.pyt_ckpt_path, + attn_implementation="eager", + torch_dtype=torch.bfloat16, + ) + tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) + + # Set pad token if not set + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Move model to GPU if available + if torch.cuda.is_available(): + model = model.cuda() + print("Model moved to CUDA") + + # Apply sparse attention to the model (with calibration if configured) + model = sparsify_model(model, args) + + # Verify outputs if requested (compares baseline vs calibrated sparse model) + if args.verify_output: + verify_outputs(model, tokenizer, args) + + # Export if requested + if args.export_dir: + print(f"\nExporting model to: {args.export_dir}") + export_dir = Path(args.export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + with torch.inference_mode(): + export_hf_checkpoint(model, export_dir=export_dir) + + tokenizer.save_pretrained(export_dir) + print(f"Model exported successfully to: {export_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + + # Model arguments + parser.add_argument( + "--pyt_ckpt_path", + type=str, + required=True, + help="Specify where the PyTorch checkpoint path is", + ) + parser.add_argument( + "--sparse_attn", + type=str, + default="skip_softmax", + choices=list(SPARSE_ATTN_CFG_CHOICES.keys()), + help="Sparse attention configuration to apply.", + ) + parser.add_argument( + "--backend", + type=str, + default="pytorch", + choices=["pytorch", "triton"], + help="Backend to use for sparse attention computation (default: pytorch)", + ) + + # Sequence length arguments + parser.add_argument( + "--seq_len", + 
type=int, + default=2048, + help="Maximum sequence length for input prompts (will be truncated if longer)", + ) + parser.add_argument( + "--num_samples", + type=int, + default=3, + help="Number of samples to use from NarrativeQA dataset", + ) + + # Generation arguments + parser.add_argument( + "--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate" + ) + parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") + + # Operation arguments + parser.add_argument( + "--verify_output", + action="store_true", + help="Verify that sparse attention outputs match baseline", + ) + parser.add_argument( + "--export_dir", + type=str, + default=None, + help="Directory to export the model with sparse attention applied", + ) + + args = parser.parse_args() + main(args) diff --git a/modelopt/torch/sparsity/attention_sparsity/__init__.py b/modelopt/torch/sparsity/attention_sparsity/__init__.py new file mode 100644 index 000000000..150f93a3a --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extensible sparse attention optimization for transformer models.""" + +# Initialize mode +from . import mode + +# Add methods to namespace +from .config import * +from .conversion import * +from .model_sparsify import * diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py new file mode 100644 index 000000000..3b616e8e3 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
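The example script above boils down to a handful of API calls. A minimal sketch of the same flow, assuming a placeholder checkpoint path (the `attn_implementation="eager"` requirement comes from the softmax-patching backend):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import modelopt.torch.sparsity.attention_sparsity as mtsa
    from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

    ckpt = "<path-or-hf-id>"  # placeholder checkpoint
    # Eager attention is required so that F.softmax is actually called and can be patched.
    model = AutoModelForCausalLM.from_pretrained(
        ckpt, attn_implementation="eager", torch_dtype=torch.bfloat16
    ).cuda()
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    # Replace attention modules in-place; config may be a dict or a SparseAttentionConfig.
    model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)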
+ +"""Calibration framework for sparse attention methods.""" + +from .calibrate import calibrate_sparse_attention +from .calibrator import DynamicThresholdCalibrator +from .dataset import RulerDatasetBuilder + +__all__ = [ + "DynamicThresholdCalibrator", + "RulerDatasetBuilder", + "calibrate_sparse_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py new file mode 100644 index 000000000..5fdab0032 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration classes for sparse attention optimization.""" + +from collections.abc import Callable +from typing import Any + +from pydantic import Field, field_validator + +from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField + +# Type definitions for sparse configuration +SparseAttributeConfig = dict[str, Any] # Configuration for a specific pattern + +SparseAttentionCfgType = dict[ + str | Callable, # Pattern or callable for matching modules + SparseAttributeConfig, # Configuration dict with threshold, enable, etc. +] + + +class SparseAttentionAttributeConfig(ModeloptBaseConfig): + """Sparse attention attribute configuration for pattern-based module config.""" + + enable: bool = ModeloptField( + default=True, + title="Enable sparse attention.", + description="If True, enables sparse attention. If False, bypasses sparsity.", + ) + + method: str = ModeloptField( + default="flash_softmax_skip", + title="Sparse attention method.", + description="The sparse attention method to use (e.g., 'flash_softmax_skip').", + ) + + threshold: float | dict[str, float] = ModeloptField( + default=1e-3, + title="Sparsity threshold.", + description=( + "Threshold for determining which attention values to skip. " + "Can be a float or dict with phase-specific values." + ), + ) + + br: int = ModeloptField( + default=128, + title="Block row size.", + description="Block row size for block-wise sparsity in Flash Attention.", + ) + + bc: int = ModeloptField( + default=128, + title="Block column size.", + description="Block column size for block-wise sparsity in Flash Attention.", + ) + + collect_stats: bool = ModeloptField( + default=False, + title="Collect statistics.", + description="Whether to collect sparsity statistics during forward pass.", + ) + + backend: str = ModeloptField( + default="pytorch", + title="Backend implementation.", + description=( + "Backend to use for sparse attention computation. " + "Only 'pytorch' is supported, which uses softmax patching with F.softmax. " + "Requires model to be loaded with attn_implementation='eager'." + ), + ) + + is_causal: bool = ModeloptField( + default=True, + title="Causal attention flag.", + description=( + "Whether the model uses causal (autoregressive) attention. 
" + "If True, sparsity statistics are calculated over the lower triangle only. " + "Defaults to True for decoder-only models like GPT, LLaMA, etc." + ), + ) + + calibration: dict | None = ModeloptField( + default=None, + title="Calibration configuration", + description=( + "Calibration settings for this pattern. " + "If provided, enables automatic threshold calibration. " + "Only one pattern should have calibration enabled." + ), + ) + + @field_validator("method") + @classmethod + def validate_method(cls, v): + """Validate method is a string.""" + if not isinstance(v, str): + raise ValueError("method must be a string") + return v + + @field_validator("backend") + @classmethod + def validate_backend(cls, v): + """Validate backend is pytorch.""" + if v != "pytorch": + raise ValueError( + f"Invalid backend: {v}. Only 'pytorch' backend is supported. " + f"Model must be loaded with attn_implementation='eager'." + ) + return v + + @field_validator("br", "bc") + @classmethod + def validate_block_size(cls, v): + """Validate block sizes are positive integers.""" + if v <= 0: + raise ValueError(f"Block size must be positive, got {v}") + return v + + @field_validator("threshold") + @classmethod + def validate_threshold(cls, v): + """Validate threshold is in valid range (0, 1) or dict with valid phases.""" + if isinstance(v, dict): + # Validate phase keys + valid_phases = {"prefill", "decode", "default"} + invalid_keys = set(v.keys()) - valid_phases + if invalid_keys: + raise ValueError( + f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" + ) + # Validate all values are in range (0, 1) + for phase, threshold in v.items(): + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + raise ValueError( + f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + ) + elif isinstance(v, (int, float)): + if v <= 0 or v >= 1: + raise ValueError(f"Threshold must be in range (0, 1), got {v}") + else: + raise ValueError(f"Threshold must be a number in range (0, 1) or dict, got {type(v)}") + return v + + +class CalibrationConfig(ModeloptBaseConfig): + """Configuration for automatic threshold calibration using RULER dataset. + + Calibration learns a dynamic threshold λ = scale_factor / sequence_length that + achieves target sparsity. Only supports prefill phase (seq_len > 1). + """ + + target_sparse_ratio: float = ModeloptField( + default=0.5, + title="Target sparsity ratio", + description="Target ratio of sparse attention blocks (0.0 to 1.0).", + ) + + samples: int = ModeloptField( + default=24, + title="Calibration samples", + description="Total number of RULER samples for calibration (distributed across length bins).", + ) + + max_seqlen: int = ModeloptField( + default=32768, + title="Maximum sequence length", + description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", + ) + + num_length_bins: int = ModeloptField( + default=4, + title="Number of length bins", + description="Number of length bins to generate (hidden parameter, default: 4).", + ) + + threshold_trials: list[float] | None = ModeloptField( + default=None, + title="Threshold trials", + description=( + "List of threshold values to test during calibration. 
" + "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" + ), + ) + + @field_validator("threshold_trials") + @classmethod + def validate_threshold_trials(cls, v): + """Validate threshold_trials are in valid range.""" + if v is not None: + if not isinstance(v, list): + raise ValueError(f"threshold_trials must be a list, got {type(v)}") + if len(v) == 0: + raise ValueError("threshold_trials must not be empty") + for threshold in v: + if not isinstance(threshold, (int, float)): + raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") + if threshold <= 0 or threshold >= 1: + raise ValueError( + f"All threshold_trials must be in range (0, 1), got {threshold}" + ) + return v + + @field_validator("target_sparse_ratio") + @classmethod + def validate_target_sparse_ratio(cls, v): + """Validate target sparsity ratio is between 0 and 1.""" + if not 0.0 <= v <= 1.0: + raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") + return v + + @field_validator("samples") + @classmethod + def validate_samples(cls, v): + """Validate samples is positive.""" + if v <= 0: + raise ValueError(f"samples must be positive, got {v}") + return v + + @field_validator("max_seqlen") + @classmethod + def validate_max_seqlen(cls, v): + """Validate max_seqlen is at least 1024.""" + if v < 1024: + raise ValueError(f"max_seqlen must be >= 1024, got {v}") + return v + + @field_validator("num_length_bins") + @classmethod + def validate_num_length_bins(cls, v): + """Validate num_length_bins is positive.""" + if v <= 0: + raise ValueError(f"num_length_bins must be positive, got {v}") + return v + + +# Pre-defined Sparse Attention Configuration +# Default configuration with block-wise sparsity optimized for Flash Attention +SKIP_SOFTMAX_DEFAULT = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "threshold": { + "prefill": 1e-3, # More aggressive during prefill + "decode": 1e-4, # Conservative during decode + }, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +# Configuration with RULER calibration +# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length +# The calibrated threshold adapts to sequence length for optimal sparsity +SKIP_SOFTMAX_CALIB = { + "method": "flash_softmax_skip", + "sparse_cfg": { + "*attn*": { + "br": 128, + "bc": 128, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { + "target_sparse_ratio": 0.5, + "samples": 120, + "max_seqlen": 8192, + }, + }, + "default": {"enable": False}, + }, +} + + +class SparseAttentionConfig(ModeloptBaseConfig): + """Base configuration for sparse attention optimization. + + This base configuration provides the common structure for all sparse + attention methods and supports pattern-based layer configuration. 
+ """ + + # Method selection + method: str = Field("flash_softmax_skip", description="Sparse attention method to use") + + # Statistics collection + collect_stats: bool = Field( + False, description="Whether to collect sparsity statistics during forward pass" + ) + + # Pattern-based sparse configuration (similar to quant_cfg in quantization) + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={"*attention*": {"enable": True}, "default": {"enable": False}}, + title="Sparse attention configuration", + description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " + "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", + validate_default=True, + ) + + # Export configuration + export_format: str | None = Field( + None, description="Export format for sparse attention (e.g., 'onnx', 'tensorrt')" + ) + + +class FlashSoftmaxSkipConfig(SparseAttentionConfig): + """Configuration for Flash Attention-aware softmax skip sparse attention.""" + + # Override method to default to flash_softmax_skip + method: str = Field( + "flash_softmax_skip", description="Sparse attention method (fixed to flash_softmax_skip)" + ) + + # Override sparse_cfg with flash_softmax_skip specific defaults + sparse_cfg: SparseAttentionCfgType = ModeloptField( + default={ + "*attention*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, # Flash Attention block rows + "bc": 128, # Flash Attention block columns + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + }, + "default": {"enable": False}, + }, + title="Flash softmax skip sparse configuration", + description="Pattern-based configuration with flash_softmax_skip specific defaults. " + "Includes FA block sizes (br, bc) and correction factor settings.", + validate_default=True, + ) + + +__all__ = [ + "SKIP_SOFTMAX_CALIB", + "SKIP_SOFTMAX_DEFAULT", + "CalibrationConfig", + "FlashSoftmaxSkipConfig", + "SparseAttentionAttributeConfig", + "SparseAttentionCfgType", + "SparseAttentionConfig", + "SparseAttributeConfig", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py new file mode 100644 index 000000000..028e2bb67 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -0,0 +1,387 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Conversion and restoration utilities for sparse attention.""" + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from modelopt.torch.opt.conversion import ModelLikeModule, ModeloptStateManager +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.utils import get_unwrapped_name + +from .config import SparseAttentionConfig +from .nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from .plugins.huggingface import register_sparse_attention_on_the_fly + + +def is_attn_sparsified(model: nn.Module) -> bool: + """Check if a model has sparse attention applied. + + Similar to quantization's is_quantized for API consistency. + + Args: + model: Model to check + + Returns: + True if model contains any SparseAttentionModule instances + """ + return any(isinstance(module, SparseAttentionModule) for module in model.modules()) + + +def convert_to_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig +) -> ConvertReturnType: + """Convert model to use sparse attention. + + Args: + model: Model to convert + config: Sparse attention configuration + + Returns: + Tuple of (converted_model, metadata) + """ + # Initialize the true module if necessary + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + # Register sparse attention modules dynamically + register_sparse_attention_on_the_fly(model) + + # Replace attention modules with sparse versions + replace_sparse_attention_modules(model, version=ModeloptStateManager(model).state_version) + + # Apply configuration to sparse attention modules + sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} + set_sparse_attention_by_cfg(model, sparse_cfg, config) + + # Create metadata + metadata = {} + update_sparse_attention_metadata(model, config, metadata) + + return model, metadata + + +def replace_sparse_attention_modules(model: nn.Module, version=None): + """Replace regular attention modules with sparse attention modules. + + Recursively replace all attention modules in the model with their sparse attention counterparts. + + Args: + model: Model to process + version: State version for tracking (optional) + """ + # Recursively replace modules + _replace_sparse_attention_modules(model, version=version) + + # Count and report replaced modules + replaced_count = sum(isinstance(m, SparseAttentionModule) for _, m in model.named_modules()) + if replaced_count > 0: + print(f"Inserted {replaced_count} sparse attention modules") + + +def _replace_sparse_attention_modules(model: nn.Module, version=None): + """Helper function for replace_sparse_attention_modules.""" + for name, child in model.named_children(): + if type(child) in SparseAttentionRegistry: + # REPLACE on the parent (model), not on child + sparse_module = SparseAttentionRegistry.convert(child) + setattr(model, name, sparse_module) + + # Now recurse into whichever module is now at `model.name` + _replace_sparse_attention_modules(getattr(model, name), version=version) + + +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict, config: SparseAttentionConfig): + """Apply sparse attention configuration to model. + + Similar to quantization's set_quantizer_by_cfg. 
+ + Args: + model: Model with sparse attention modules + sparse_cfg: Sparse configuration dictionary + config: Global sparse attention configuration + """ + sparse_cfg = sparse_cfg.copy() + + # Apply default first if exists + if "default" in sparse_cfg: + set_sparse_attention_attribute(model, "*", sparse_cfg["default"], config) + sparse_cfg.pop("default") + + # Apply pattern-specific configs + for pattern, cfg in sparse_cfg.items(): + set_sparse_attention_attribute(model, pattern, cfg, config) + + +def set_sparse_attention_attribute( + model: nn.Module, + wildcard_or_filter: str | Callable, + attribute_cfg: dict[str, Any], + global_config: SparseAttentionConfig, +): + """Set sparse attention attributes for modules matching pattern. + + Similar to quantization's set_quantizer_attribute. + + Args: + model: Model to configure + wildcard_or_filter: Pattern to match module names + attribute_cfg: Attributes to apply + global_config: Global sparse attention configuration + """ + # Merge global config fields with pattern config + # Filter out model-level configs that shouldn't be passed to modules + module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} + + full_cfg = { + "method": global_config.method, + "collect_stats": global_config.collect_stats, + **module_cfg, + } + + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + # Check pattern match + matched = False + if isinstance(wildcard_or_filter, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter) + elif callable(wildcard_or_filter): + matched = wildcard_or_filter(name) + else: + continue + + if matched: + # Apply config using the same method as TensorQuantizer + module.set_from_attribute_config(full_cfg) + + +def restore_sparse_attention_model( + model: ModelLikeModule, config: SparseAttentionConfig, metadata: MetadataDict +) -> nn.Module: + """Restore sparse attention model from saved state. + + Args: + model: Model to restore + config: Sparse attention configuration + metadata: Saved metadata + + Returns: + Restored model + """ + # Convert to sparse attention model + model, _ = convert_to_sparse_attention_model(model, config) + + # Restore sparse attention state from metadata + if "sparse_attention_state" in metadata: + restore_sparse_attention_state(model, metadata["sparse_attention_state"]) + + return model + + +def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]): + """Restore sparse attention state from state dict. + + Args: + model: Model with sparse attention modules + state_dict: Saved state dictionary + """ + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + if module_name in state_dict: + module_state = state_dict[module_name] + + # Restore method and config + if "method" in module_state: + module._method = module_state["method"] + if "method_config" in module_state: + # Restore config attributes + for key, val in module_state["method_config"].items(): + setattr(module, f"_{key}", val) + + # Re-setup with restored config + module._setup() + + +def update_sparse_attention_metadata( + model: nn.Module, config: SparseAttentionConfig, metadata: MetadataDict +) -> None: + """Update metadata with sparse attention state. 
+ + Args: + model: Model with sparse attention + config: Configuration used + metadata: Metadata dict to update + """ + sparse_state = {} + + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + module_name = get_unwrapped_name(name, model) + + # Collect method config from module attributes + method_config = { + k[1:]: v + for k, v in module.__dict__.items() + if k.startswith("_") and k not in ("_method", "_enabled", "_sparse_method_instance") + } + + module_state = { + "method": module._sparse_method_instance.name, + "method_config": method_config, + } + + sparse_state[module_name] = module_state + + metadata["sparse_attention_state"] = sparse_state + metadata["sparse_attention_config"] = ( + config.model_dump() if hasattr(config, "model_dump") else vars(config) + ) + + +def disable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Disable sparse attention for matching modules. + + Similar to mtq.disable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*lm_head*", "*layer_0*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Disable sparse attention for lm_head + >>> sparse_attn.disable_sparse_attention(model, "*lm_head*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.disable() + + +def enable_sparse_attention(model: nn.Module, wildcard_or_filter_func: str | Callable): + """Enable sparse attention for matching modules. + + Similar to mtq.enable_quantizer for API consistency. + + Args: + model: Model with sparse attention applied + wildcard_or_filter_func: Wildcard string or filter function to match module names. + For example: "*attention*", "*attn*", etc. + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> # Re-enable sparse attention for all attention modules + >>> sparse_attn.enable_sparse_attention(model, "*attention*") + """ + for name, module in model.named_modules(): + if not isinstance(module, SparseAttentionModule): + continue + + matched = False + if isinstance(wildcard_or_filter_func, str): + matched = fnmatch.fnmatch(name, wildcard_or_filter_func) + elif callable(wildcard_or_filter_func): + matched = wildcard_or_filter_func(name) + + if matched: + module.enable() + + +def print_sparse_attention_summary(model: nn.Module): + """Print summary of sparse attention modules in the model. + + Similar to mtq.print_quant_summary for API consistency. 
+ + Args: + model: Model with sparse attention applied + + Prints: + - Total sparse attention modules + - Enabled vs disabled count + - Method distribution + - Configuration summary by module + + Example: + >>> import modelopt.torch.sparsity.attention_sparsity as sparse_attn + >>> model = sparse_attn.sparsify(model, config) + >>> sparse_attn.print_sparse_attention_summary(model) + """ + sparse_modules = [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + sparse_modules.append((name, module)) + + if not sparse_modules: + print("No sparse attention modules found in model") + return + + enabled_count = sum(1 for _, m in sparse_modules if m.is_enabled) + disabled_count = len(sparse_modules) - enabled_count + + # Count methods + method_counts = {} + for _, module in sparse_modules: + method = getattr(module, "_method", "unknown") + method_counts[method] = method_counts.get(method, 0) + 1 + + print(f"\n{'=' * 70}") + print(f"{'Sparse Attention Summary':^70}") + print(f"{'=' * 70}") + print(f"Total sparse attention modules: {len(sparse_modules)}") + print(f" Enabled: {enabled_count}") + print(f" Disabled: {disabled_count}") + + if method_counts: + print("\nMethods:") + for method, count in sorted(method_counts.items()): + print(f" {method}: {count}") + + print(f"\n{'Module Details':^70}") + print(f"{'-' * 70}") + + for name, module in sparse_modules: + status = "✓" if module.is_enabled else "✗" + method = getattr(module, "_method", "unknown") + threshold = getattr(module, "_threshold", "N/A") + + # Format threshold nicely + if isinstance(threshold, dict): + threshold_str = str(threshold) + elif isinstance(threshold, float): + threshold_str = f"{threshold:.2e}" + else: + threshold_str = str(threshold) + + print(f"{status} {name}") + print(f" Method: {method}, Threshold: {threshold_str}") + + print(f"{'=' * 70}\n") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py new file mode 100644 index 000000000..5120bd755 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention methods package.""" + +from .registry import SparseAttentionMethod, get_sparse_method, register_sparse_method + +__all__ = [ + "SparseAttentionMethod", + "get_sparse_method", + "register_sparse_method", +] + +# Import method implementations to trigger registration +from . 
import flash_softmax_skip diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py new file mode 100644 index 000000000..04b696d11 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py @@ -0,0 +1,289 @@ +"""Flash Attention-aware softmax skip method for sparse attention. + +This module implements block-wise sparsity that aligns with Flash Attention's +processing pattern for optimal performance. +""" + +import math + +import numpy as np +import torch + +from . import SparseAttentionMethod, register_sparse_method + + +@register_sparse_method("flash_softmax_skip") +class FlashSoftmaxSkipMethod(SparseAttentionMethod): + """Flash Attention-aware softmax skip sparse attention method. + + Implements row-level block-wise sparsity aligned with Flash Attention's + processing pattern for optimal performance and accuracy. + """ + + def __init__(self, method_config: dict | None = None): + """Initialize Flash softmax skip method. + + Args: + method_config: Configuration dict with threshold, br, bc, is_causal, etc. + """ + config = method_config or {} + + # Extract configuration + self.threshold_config = config.get("threshold", 1e-4) + self.br = config.get("br", 128) + self.bc = config.get("bc", 128) + self.enable_correction_factor = config.get("enable_correction_factor", True) + self.collect_stats = config.get("collect_stats", True) + self.phase = config.get("phase", None) + self.backend = config.get("backend", "pytorch") + self.is_causal = config.get("is_causal", True) + # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold + self._calibration_mode = False + + # Initialize threshold + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + "default", self.threshold_config.get("prefill", 1e-4) + ) + else: + self.threshold = self.threshold_config + + def _update_threshold(self, phase: str): + """Update threshold based on phase.""" + if isinstance(self.threshold_config, dict): + self.threshold = self.threshold_config.get( + phase, self.threshold_config.get("default", self.threshold) + ) + + def set_calibration_mode(self, enabled: bool): + """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + self._calibration_mode = enabled + + def _infer_phase(self, attention_scores: torch.Tensor) -> str: + """Infer phase from attention scores shape.""" + return "decode" if attention_scores.shape[2] == 1 else "prefill" + + def _reshape_to_blocks( + self, tensor: torch.Tensor, br: int, bc: int + ) -> tuple[torch.Tensor, ...]: + """Reshape tensor into blocks for Flash Attention processing. 
+ + Args: + tensor: Input tensor of shape [batch, heads, seq_q, seq_k] + br: Block row size + bc: Block column size + + Returns: + Tuple of (blocked_tensor, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k) + """ + batch_size, num_heads, seq_q, seq_k = tensor.shape + + # Calculate padding needed + padded_seq_q = math.ceil(seq_q / br) * br + padded_seq_k = math.ceil(seq_k / bc) * bc + + # Pad tensor if necessary + if padded_seq_q != seq_q or padded_seq_k != seq_k: + pad_q = padded_seq_q - seq_q + pad_k = padded_seq_k - seq_k + # Use dtype min instead of -inf for numerical stability + pad_value = torch.finfo(tensor.dtype).min + tensor = torch.nn.functional.pad(tensor, (0, pad_k, 0, pad_q), value=pad_value) + + # Reshape to blocks + num_block_rows = padded_seq_q // br + num_block_cols = padded_seq_k // bc + + # Keep natural order for row-level processing: [batch, heads, block_rows, br, block_cols, bc] + blocked = tensor.view(batch_size, num_heads, num_block_rows, br, num_block_cols, bc) + + return blocked, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k + + def calc_correction_factor_and_p( + self, attn_weights: torch.Tensor, phase: str + ) -> tuple[torch.Tensor, dict]: + """Calculate sparse mask and statistics for Flash Attention. + + Implements block-wise sparsity compatible with Flash Attention's online softmax: + 1. Reshape attention scores into 128x128 blocks + 2. Track block-wise maximum values (simulating Flash Attention's row processing) + 3. Compute cumulative maximum across blocks (for online normalization) + 4. Apply threshold: mask blocks where p = score - cummax < log(threshold) + 5. Calculate correction factor and sparsity statistics + + Args: + attn_weights: Pre-softmax attention scores [batch, heads, seq_q, seq_k] + phase: "prefill" (seq_q > 1) or "decode" (seq_q = 1) + + Returns: + element_mask: Boolean mask [batch, heads, seq_q, seq_k] + stats: Dict with sparsity, correction_factor, total_blocks, etc. 
+ """ + batch_size, num_heads, seq_q, seq_k = attn_weights.shape + + # Calculate threshold + threshold_scale_factor = getattr(self, "threshold_scale_factor", None) + if threshold_scale_factor: + # Use calibrated dynamic threshold: λ = scale_factor / length + log_threshold = np.log(threshold_scale_factor / seq_k) + else: + # Use static threshold from config + log_threshold = np.log(self.threshold) + + if phase == "prefill": + blocked_attn, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k = ( + self._reshape_to_blocks(attn_weights, self.br, self.bc) + ) + + # Step 1: Compute maximum value in each block + # For each 128x128 block, find max across the 128 columns + # blocked_attn: [batch, heads, block_rows, br=128, block_cols, bc=128] + # block_max: [batch, heads, block_rows, br=128, block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across blocks (left to right) + # This simulates Flash Attention's online softmax normalization + # block_max_cummax: [batch, heads, block_rows, br=128, block_cols] + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor (how often max changes) + # Used by Flash Attention to adjust running sum when max increases + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize attention scores by cumulative max + # p represents log-space difference: log(score) - log(cummax) + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block-level mask + # Keep blocks where at least one element exceeds log(threshold) + p_larger_than_thresh = p > log_threshold + # Reduce over bc (128 cols), then br (128 rows) to get block-level decision + # Result: [batch, heads, block_rows, block_cols] + block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2) + + # Step 6: Expand block mask back to element level + # All 128x128 elements in a block share the same mask value + # [batch, heads, block_rows, block_cols] -> [batch, heads, block_rows, br=128, block_cols, bc=128] + element_mask = block_mask.unsqueeze(-2).unsqueeze(-1).expand_as(blocked_attn) + + # Step 7: Reshape to original attention shape and remove padding + element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 8: Calculate sparsity statistics + # Count kept blocks (averaged across batch and heads) + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + + # Total valid blocks (lower triangle only for causal attention) + # Note: Causal mask pre-applied by attention module, so block_mask naturally + # has zeros in upper triangle. We only count lower triangle for denominator. 
+ total_blocks = ( + num_block_rows * (num_block_rows + 1) // 2 # Causal: N(N+1)/2 + if self.is_causal + else num_block_rows * num_block_cols # Non-causal: N*N + ) + sparsity = 1 - (kept_blocks / total_blocks) + else: # decode + blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( + attn_weights, 1, self.bc + ) + + # Decode: Single query row attends to all past key blocks + # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc=128] + + # Step 1: Find maximum in each key block + # block_max: [batch, heads, 1, 1, num_block_cols] + block_max = blocked_attn.max(dim=-1)[0] + + # Step 2: Track cumulative maximum across key blocks (left to right) + # Simulates Flash Attention's online softmax normalization + block_max_cummax = block_max.cummax(dim=-1)[0] + + # Step 3: Calculate correction factor + # Tracks how often the maximum increases (needed for Flash Attention rescaling) + block_max_larger = torch.ones_like(block_max) + block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] + correction_factor = float(torch.sum(block_max_larger) / torch.numel(block_max_larger)) + + # Step 4: Normalize scores by cumulative max + # p = log(score) - log(cummax) in log-space + p = blocked_attn - block_max_cummax[..., None] + + # Step 5: Apply threshold and create block mask + # Keep blocks where at least one element exceeds threshold + p_larger_than_thresh = p > log_threshold + block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False) + + # Step 6: Expand to element level and remove padding + element_mask = block_mask[..., None].expand_as(blocked_attn) + element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + + # Step 7: Calculate statistics + kept_blocks = block_mask.sum().item() / (batch_size * num_heads) + total_blocks = num_block_cols + sparsity = 1 - (kept_blocks / total_blocks) + + # Create stats dictionary + stats = { + "correction_factor": correction_factor if self.enable_correction_factor else 1.0, + "sparsity": sparsity, + "phase": phase, + "total_blocks": total_blocks, + "sparse_blocks": int(sparsity * total_blocks), + "sample_length": seq_k, + } + + return element_mask, stats + + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply Flash Attention-aware block-wise sparsity. 
+ + Args: + query: Query tensor (unused, for API compatibility) + key: Key tensor (unused, for API compatibility) + value: Value tensor (unused, for API compatibility) + attention_scores: Attention scores tensor with shape [batch, heads, seq_q, seq_k] + + Returns: + Tuple with potentially modified attention_scores + """ + # Attention scores must be provided for sparse attention + assert attention_scores is not None, "attention_scores must be provided for apply_sparsity" + + # Attention scores are always 4D: [batch, heads, seq_q, seq_k] + assert len(attention_scores.shape) == 4, ( + f"Expected 4D attention scores, got shape {attention_scores.shape}" + ) + + # Infer phase from tensor shape + phase = self._infer_phase(attention_scores) + + # Update threshold for the detected phase (skip during calibration) + if not self._calibration_mode: + self._update_threshold(phase) + + # Apply block-wise sparsity + sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) + + # Store stats for module to collect (doesn't persist across calls) + self._last_stats = stats + + # Apply mask to create sparse scores + mask_value = torch.finfo(attention_scores.dtype).min + sparse_scores = attention_scores.masked_fill(~sparse_mask, mask_value) + + return query, key, value, sparse_scores + + @property + def name(self) -> str: + """Method identifier.""" + return "flash_softmax_skip" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py new file mode 100644 index 000000000..081ad9e27 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry and base class for sparse attention methods.""" + +from abc import ABC, abstractmethod + +import torch + + +class SparseAttentionMethod(ABC): + """Base class for sparse attention methods.""" + + @abstractmethod + def apply_sparsity( + self, + query: torch.Tensor | None = None, + key: torch.Tensor | None = None, + value: torch.Tensor | None = None, + attention_scores: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """Apply sparsity to attention computation. + + Args: + query: Query tensor + key: Key tensor + value: Value tensor + attention_scores: Pre-computed attention scores + + Returns: + Tuple of (query, key, value, attention_scores) with sparsity applied + """ + + @property + @abstractmethod + def name(self) -> str: + """Method name identifier.""" + + +# Method Registry with versioning support +_SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} + + +def register_sparse_method(name: str, version: str = "v1"): + """Decorator to register sparse attention methods with version support. 
+ + Args: + name: Method name to register + version: Version string (default: "v1") + + Example: + @register_sparse_method("my_method", version="v3") + class MyMethodV3(SparseAttentionMethod): + ... + """ + + def decorator(cls: type[SparseAttentionMethod]): + if name not in _SPARSE_ATTENTION_METHODS: + _SPARSE_ATTENTION_METHODS[name] = {} + + if version in _SPARSE_ATTENTION_METHODS[name]: + import warnings + + warnings.warn( + f"Overriding existing sparse attention method: {name}@{version}", + RuntimeWarning, + stacklevel=2, + ) + + _SPARSE_ATTENTION_METHODS[name][version] = cls + return cls + + return decorator + + +def get_sparse_method(name: str, version: str | None = None) -> type[SparseAttentionMethod]: + """Get sparse attention method by name and optional version. + + Args: + name: Method name to retrieve + version: Optional version string. If None, uses latest version. + + Returns: + Method class + + Raises: + ValueError: If method name or version is not registered + + Example: + >>> get_sparse_method("flash_softmax_skip") # Latest version + >>> get_sparse_method("flash_softmax_skip", "v1") # Specific version + """ + if name not in _SPARSE_ATTENTION_METHODS: + available = list(_SPARSE_ATTENTION_METHODS.keys()) + raise ValueError(f"Unknown sparse attention method: {name}. Available: {available}") + + method_versions = _SPARSE_ATTENTION_METHODS[name] + + if not version: + version = sorted(method_versions.keys())[-1] + + if version not in method_versions: + available_versions = list(method_versions.keys()) + raise ValueError( + f"Unknown version {version} for method {name}. Available: {available_versions}" + ) + + return method_versions[version] diff --git a/modelopt/torch/sparsity/attention_sparsity/mode.py b/modelopt/torch/sparsity/attention_sparsity/mode.py new file mode 100644 index 000000000..f389509a5 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/mode.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse attention mode descriptor for ModelOpt.""" + +from modelopt.torch.opt.config import ModeloptBaseConfig +from modelopt.torch.opt.mode import ( + ConvertEntrypoint, + ModeDescriptor, + RestoreEntrypoint, + UpdateEntrypoint, + _ModeRegistryCls, +) + +from .config import SparseAttentionConfig +from .conversion import ( + convert_to_sparse_attention_model, + restore_sparse_attention_model, + update_sparse_attention_metadata, +) + +# Create registry for sparse attention modes +SparseAttentionModeRegistry = _ModeRegistryCls("sparse_attention") + + +@SparseAttentionModeRegistry.register_mode +class SparseAttentionModeDescriptor(ModeDescriptor): + """Mode descriptor for sparse attention optimization. + + This mode enables various sparse attention methods to reduce + computational complexity and memory usage in transformer models. 
+ """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "sparse_attention" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return SparseAttentionConfig + + @property + def next_prohibited_modes(self) -> set[str] | None: + """Modes that should not be applied after this mode.""" + # Can work with quantization but not with weight sparsity + return {"sparsity"} + + @property + def export_mode(self) -> str | None: + """The mode that corresponds to the export mode of this mode.""" + return "export_sparse_attention" + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_sparse_attention_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_sparse_attention_model + + @property + def update_for_save(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before saving.""" + return update_sparse_attention_metadata + + @property + def update_for_new_mode(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the model's state before new mode.""" + return update_sparse_attention_metadata diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py new file mode 100644 index 000000000..908f3ad89 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main API functions for sparse attention optimization.""" + +from typing import Any + +import torch + +from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode +from modelopt.torch.opt.searcher import ForwardLoop + +from .calibration import calibrate_sparse_attention +from .config import SparseAttentionConfig +from .mode import SparseAttentionModeRegistry + +__all__ = [ + "calibrate", + "sparsify", +] + + +def sparsify( + model: torch.nn.Module, + config: dict[str, Any] | SparseAttentionConfig, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Applies sparse attention optimization to the model in-place. + + This method performs replacement of attention modules with their sparse counterparts and + optionally performs calibration as specified by ``config``. + ``forward_loop`` is used to forward data through the model and gather statistics for calibration. + + Args: + model: A pytorch model + config: A dictionary or an instance of + :class:`SparseAttentionConfig ` + specifying the values for keys ``"sparse_cfg"``, ``"method"``, and optionally ``"calibration"``. + + The ``"sparse_cfg"`` key specifies the sparse attention configurations. 
+ The ``"method"`` key specifies the sparse attention method (e.g., "softmax_skip"). + The ``"calibration"`` key specifies calibration settings if automatic threshold tuning is desired. + + Sparse attention configurations is a dictionary mapping wildcards or filter functions + to its sparse attention attributes. The wildcards or filter functions are matched + against the module names. The sparse attention attributes include ``"threshold"``, + ``"enable"``, and method-specific parameters. + + An example ``config`` dictionary is given below: + + .. code-block::python + + config = { + "method": "softmax_skip", + "sparse_cfg": { + # Phase-aware thresholds with backend selection and calibration + "*attention*": { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "backend": "pytorch", # Only pytorch backend supported + "enable": True, + "calibration": { # Optional: enables automatic threshold calibration + "target_sparse_ratio": 0.5, + "samples": 48, + "max_seqlen": 8192, + }, + }, + # Disable for specific layers + "*layer_0*": {"enable": False}, + # Default settings + "default": {"enable": False}, + }, + } + + The ``"backend"`` parameter must be set to ``"pytorch"``: + + - ``"pytorch"``: Softmax patching approach (only supported backend) + + This requires the model to be loaded with ``attn_implementation="eager"``. + + forward_loop: A callable that forwards all calibration data through the model. This is used + to gather statistics for calibration. It should take model as the argument. It does not need + to return anything. + + This argument is only required when calibration is enabled in the config. + + Here are a few examples for correct ``forward_loop`` definitions: + + Example 1: + + .. code-block:: + + def forward_loop(model) -> None: + # iterate over the data loader and forward data through the model + for batch in data_loader: + model(batch) + + Example 2: + + .. code-block:: + + def forward_loop(model) -> float: + # evaluate the model on the task + return evaluate(model, task, ....) + + .. note:: + + Calibration does not require forwarding the entire dataset through the model. + Please subsample the dataset or reduce the number of batches if needed. + + .. important:: + + The model must always be loaded with ``attn_implementation="eager"`` + for sparse attention to work correctly: + + .. code-block:: python + + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation="eager", # Required for sparse attention + torch_dtype=torch.bfloat16, + ) + + This is because sparse attention works by patching torch.nn.functional.softmax, + which is only called in the eager attention implementation. + + Returns: + A pytorch model which has sparse attention applied and optionally calibrated. + """ + model = apply_mode( + model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry + ) + + # Calibrate the sparsity ratio of the attention modules + return calibrate(model, forward_loop=forward_loop) + + +def calibrate( + model: torch.nn.Module, + forward_loop: ForwardLoop | None = None, +) -> torch.nn.Module: + """Calibrates sparse attention thresholds based on target sparsity. + + This function performs calibration to find optimal thresholds that achieve + the target sparsity ratio specified in the sparse attention configuration. + + Args: + model: A pytorch model with sparse attention already applied + forward_loop: Optional callable that forwards calibration data through the model. 
+ It should take model as the argument and can optionally return metrics. + If None, will auto-generate RULER dataset for calibration. + + Returns: + The calibrated model with optimized sparse attention thresholds. + If no calibration is configured, returns the model unchanged. + """ + # Get the sparse attention config from the model's state + if not ModeloptStateManager.is_converted(model): + return model + + manager = ModeloptStateManager(model) + + sparse_attn_config = next( + (state["config"] for name, state in manager._state if name == "sparse_attention"), None + ) + + if sparse_attn_config is None: + return model + + # Check if calibration is configured in any sparse_cfg pattern + # Note: sparse_attn_config is always a dict (stored via config.model_dump()) + sparse_cfg = sparse_attn_config.get("sparse_cfg", {}) + + has_calibration = any( + isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() + ) + + if not has_calibration: + return model + + # Run calibration (handles stats collection internally) + calibrate_sparse_attention(model, sparse_attn_config, forward_loop=forward_loop) + + return model diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py new file mode 100644 index 000000000..00ff275bc --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Neural network modules for sparse attention.""" + +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry + +__all__ = ["SparseAttentionModule", "SparseAttentionRegistry"] diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py new file mode 100644 index 000000000..a45931224 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
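Putting the entry points above together, a typical end-to-end flow looks roughly like the sketch below (the checkpoint name and prompt are only examples; any HuggingFace causal LM loaded with eager attention should work):

.. code-block:: python

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import modelopt.torch.sparsity.attention_sparsity as mtsa
    from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

    model_name = "meta-llama/Llama-3.1-8B-Instruct"  # example checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="eager",  # required: sparsity is applied by patching F.softmax
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)

    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(output[0], skip_special_tokens=True))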
+ +"""Extensible sparse attention module.""" + +import torch +import torch.nn.functional as F + +from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls +from modelopt.torch.quantization.utils import replace_function + +from ..config import SparseAttentionAttributeConfig +from ..methods import get_sparse_method +from .stats_manager import SparseAttentionStatsManager + + +class SparseAttentionModule(DynamicModule): + """Generic sparse attention module wrapper for applying sparsity to attention layers. + + This module wraps existing attention implementations to add sparse attention + capabilities by patching torch.nn.functional.softmax. + + Forward Flow: + ------------- + 1. Check if sparse attention is enabled (pass-through if disabled) + 2. Create softmax patch context with sparse_softmax function + 3. Apply sparse attention by patching F.softmax: + - Patches torch.nn.functional.softmax with sparse_softmax + - sparse_softmax applies method's sparsity logic before softmax + 4. Forward through original attention with sparsity applied + + Requirements: + ------------- + - Model must be loaded with attn_implementation="eager" for proper softmax interception + - Only PyTorch backend is supported (patches F.softmax) + + Attributes: + ----------- + _enabled: bool + Whether sparse attention is enabled + _method: str + The sparse attention method to use (e.g., "flash_softmax_skip") + _method_config: dict + Configuration dictionary for the sparse method (threshold, br, bc, etc.) + _sparse_method_instance: SparseAttentionMethod + Instance of the configured sparse attention method + """ + + def set_from_attribute_config( + self, attribute_cfg: SparseAttentionAttributeConfig | dict | None = None + ): + """Set sparse attention attributes from configuration. + + Similar to TensorQuantizer.set_from_attribute_config. + + Args: + attribute_cfg: Sparse attention attribute configuration. + If None, uses default SparseAttentionAttributeConfig. 
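For example, a single ``sparse_cfg`` entry expands into an attribute dictionary like the following, with ``module`` being a ``SparseAttentionModule`` (values mirror the defaults used elsewhere in this patch):

.. code-block:: python

    module.set_from_attribute_config(
        {
            "method": "flash_softmax_skip",
            "enable": True,
            "threshold": {"prefill": 1e-3, "decode": 1e-5},
            "br": 128,  # Flash Attention block rows
            "bc": 128,  # Flash Attention block columns
            "backend": "pytorch",  # only supported backend
        }
    )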
+ """ + # Use default config if not provided + attribute_cfg = ( + attribute_cfg if attribute_cfg is not None else SparseAttentionAttributeConfig() + ) + + # Store raw config for method initialization + self._method_config = {} + + # Define which attributes are method-specific vs module-specific + # Module-specific attributes control the SparseAttentionModule behavior + _module_attributes = {"enable", "method"} + + # Custom setters for special module attributes + _custom_setters = { + "enable": ("_enabled", lambda val: bool(val)), + "method": ("_method", lambda val: str(val)), + } + + # Process each attribute from config + for attribute, val in attribute_cfg.items(): + # Validate attribute if using config class + if hasattr(SparseAttentionAttributeConfig, "model_fields"): + assert attribute in SparseAttentionAttributeConfig.model_fields, ( + f"{attribute} is not a valid SparseAttentionModule attribute" + ) + + if attribute in _module_attributes: + # Module-level attribute: store with underscore prefix + attr_name, setter = _custom_setters.get(attribute, (f"_{attribute}", lambda v: v)) + setattr(self, attr_name, setter(val)) + else: + # Method-specific attribute: store in config dict + self._method_config[attribute] = val + + # Initialize sparse method instance + self._init_sparse_method() + + def _init_sparse_method(self): + """Initialize the sparse method instance.""" + method_class = get_sparse_method(self._method) + + # Initialize the sparse method instance + # _method_config is always initialized in set_from_attribute_config + self._sparse_method_instance = method_class(method_config=self._method_config) # type: ignore[call-arg] + + def enable(self): + """Enable sparse attention for this module.""" + self._enabled = True + + def disable(self): + """Disable sparse attention for this module.""" + self._enabled = False + + @property + def is_enabled(self) -> bool: + """Check if sparse attention is enabled.""" + return getattr(self, "_enabled", True) + + def get_stats(self) -> dict: + """Get sparsity statistics from the stats manager. + + Returns: + Dictionary with sparsity statistics including 'average_sparsity' if available. + Returns empty dict if stats manager is not enabled. + """ + if self._stats_manager is not None and self._stats_manager.enabled: + return self._stats_manager.get_summary() + return {} + + def _setup(self): + """Setup called by DynamicModule.""" + # Apply default configuration if not yet configured + if not hasattr(self, "_method"): + self.set_from_attribute_config(None) + + # Create stats manager if stats collection is enabled + if self._method_config.get("collect_stats", False): + self._stats_manager = SparseAttentionStatsManager( + module_name="sparse_attention", enabled=True + ) + else: + self._stats_manager = None + + def forward(self, *args, **kwargs): + """Forward with selected sparse attention method. + + This method dispatches to the appropriate sparse attention implementation + based on the configured method and backend. 
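Sparsity can also be toggled per module after conversion; disabled modules fall straight through to the wrapped attention. A short illustrative sketch (the layer-name check is only an example):

.. code-block:: python

    from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import (
        SparseAttentionModule,
    )

    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            if ".layers.0." in name:  # example: keep the first decoder layer dense
                module.disable()
            print(name, module.is_enabled, module.get_stats())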
+ """ + # Pass through if sparse attention is disabled + if not self.is_enabled: + return super().forward(*args, **kwargs) + + # Get the appropriate context manager for this configuration + context = self._get_sparse_context() + + # Apply sparse attention through the context + with context: + result = super().forward(*args, **kwargs) + + # Collect stats if manager is available + if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): + self._stats_manager.collect(self._sparse_method_instance._last_stats) + + return result + + def _get_sparse_context(self): + """Get the softmax patch context for applying sparse attention.""" + return self._create_softmax_patch_context() + + def _create_softmax_patch_context(self): + """Create context manager for patching softmax function.""" + return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) + + def _create_sparse_softmax(self): + """Create sparse softmax function for current method.""" + original_softmax = F.softmax + + def sparse_softmax(input, dim=-1, *args, **kwargs): + # Let the method handle the sparsification + _, _, _, sparse_input = self._sparse_method_instance.apply_sparsity( + None, None, None, input + ) + + # Use sparse input if modified, otherwise use original + if sparse_input is not None: + return original_softmax(sparse_input, dim, *args, **kwargs) + return original_softmax(input, dim, *args, **kwargs) + + return sparse_softmax + + +# Create registry for sparse attention modules +SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py new file mode 100644 index 000000000..ba8c8b821 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugins for sparse attention integration with various frameworks.""" + +from .huggingface import register_sparse_attention_on_the_fly + +__all__ = [ + "register_sparse_attention_on_the_fly", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py new file mode 100644 index 000000000..0012257b6 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic sparse attention registration for HuggingFace models.""" + +import torch.nn as nn +import transformers + +from modelopt.torch.opt.dynamic import DynamicModule + +from ..nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry + + +class _GenericSparseAttention(SparseAttentionModule): + """Generic sparse attention that works with any HF attention module. + + This class provides a universal sparse attention wrapper that can + work with various transformer attention implementations. + """ + + def _setup(self): + """Setup sparse attention for any attention type. + + The base SparseAttentionModule handles detection and initialization. + """ + super()._setup() + + def get_attn_type(self, attn_module) -> type: + """Get the original attention type. + + Args: + attn_module: Attention module (possibly wrapped) + + Returns: + Original class type + """ + # If this is a DynamicModule, get the original class + if isinstance(attn_module, DynamicModule): + return attn_module.get_original_cls_by_level(level=0) + return type(attn_module) + + +def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: + """Dynamically register sparse attention for any model. + + This function automatically detects attention modules in the model + and registers them for sparse attention optimization. + + Args: + model: Model to process + + Returns: + True if any modules were registered + """ + if not _is_supported_model(model): + return False + + registered_count = 0 + attention_types = set() + + for name, module in model.named_modules(): + # Skip if already a sparse attention module + if isinstance(module, SparseAttentionModule): + continue + + # Check if this is an attention module by name + module_type = type(module) + type_name = module_type.__name__ + + # Common attention module patterns + is_attention = ( + "attention" in type_name.lower() + or type_name.endswith("Attention") + or type_name.endswith("SelfAttention") + ) + + if is_attention and module_type not in SparseAttentionRegistry: + # Register attention type + if module_type not in attention_types: + SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) + attention_types.add(module_type) + registered_count += 1 + print(f"Registered {type_name} for sparse attention optimization") + + if registered_count > 0: + print(f"Dynamically registered {registered_count} attention module types for sparsity") + + return registered_count > 0 + + +def _is_supported_model(model: nn.Module) -> bool: + """Check if model is supported for sparse attention. + + Supports HuggingFace PreTrainedModel and any PyTorch model with attention modules. 
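``register_sparse_attention_on_the_fly`` is imported by ``conversion.py``, but it can also be called directly to check which attention classes would be wrapped; a hedged sketch:

.. code-block:: python

    from modelopt.torch.sparsity.attention_sparsity.plugins import (
        register_sparse_attention_on_the_fly,
    )

    # Returns True if at least one *Attention class was registered for wrapping.
    if register_sparse_attention_on_the_fly(model):
        print("model has attention modules that can be sparsified")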
+ + Args: + model: Model to check + + Returns: + True if model is supported + """ + # Check for HuggingFace PreTrainedModel + try: + if isinstance(model, transformers.PreTrainedModel): + return True + except ImportError: + pass + + # Support any PyTorch model with attention modules + return isinstance(model, nn.Module) From 5d027e0a3e7cb01097f1422030d63080c35a2fda Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Wed, 5 Nov 2025 17:21:28 -0800 Subject: [PATCH 2/2] Add unit and GPU tests for core sparse attention functionality Signed-off-by: Kai Xu --- .../attention_sparsity/hf_sa.py} | 75 ++--- .../{ => weight_sparsity}/.gitignore | 0 .../{ => weight_sparsity}/README.md | 0 .../{ => weight_sparsity}/data_prep.py | 0 .../{ => weight_sparsity}/eval.py | 0 .../export_trtllm_ckpt.py | 0 .../{ => weight_sparsity}/finetune.py | 15 + .../{ => weight_sparsity}/hf_pts.py | 0 .../{ => weight_sparsity}/launch_finetune.sh | 0 .../{ => weight_sparsity}/requirements.txt | 0 .../{ => weight_sparsity}/utils.py | 0 .../calibration/__init__.py | 26 -- .../sparsity/attention_sparsity/config.py | 164 ++-------- .../sparsity/attention_sparsity/conversion.py | 57 +--- .../attention_sparsity/methods/__init__.py | 2 +- ..._softmax_skip.py => flash_skip_softmax.py} | 48 +-- .../attention_sparsity/methods/registry.py | 37 ++- .../attention_sparsity/model_sparsify.py | 82 +---- .../attention_sparsity/nn/__init__.py | 20 -- .../attention_sparsity/plugins/huggingface.py | 18 +- .../{nn => }/sparse_attention.py | 36 +-- .../torch_sparsity/sparse_attention_common.py | 195 ++++++++++++ .../test_attention_sparsity.py | 52 ++++ .../test_llama_sparsify.py | 8 +- .../test_attention_sparsity_gpu.py | 144 +++++++++ .../test_integration_gpu.py | 190 ++++++++++++ .../test_flash_skip_softmax.py | 282 ++++++++++++++++++ .../test_sparse_attention_config.py | 129 ++++++++ .../test_sparse_attention_conversion.py | 208 +++++++++++++ .../test_sparse_attention_mode.py | 43 +++ 30 files changed, 1405 insertions(+), 426 deletions(-) rename examples/{llm_sparse_attention/hf_spar_attn.py => llm_sparsity/attention_sparsity/hf_sa.py} (83%) rename examples/llm_sparsity/{ => weight_sparsity}/.gitignore (100%) rename examples/llm_sparsity/{ => weight_sparsity}/README.md (100%) rename examples/llm_sparsity/{ => weight_sparsity}/data_prep.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/eval.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/export_trtllm_ckpt.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/finetune.py (95%) rename examples/llm_sparsity/{ => weight_sparsity}/hf_pts.py (100%) rename examples/llm_sparsity/{ => weight_sparsity}/launch_finetune.sh (100%) rename examples/llm_sparsity/{ => weight_sparsity}/requirements.txt (100%) rename examples/llm_sparsity/{ => weight_sparsity}/utils.py (100%) delete mode 100644 modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py rename modelopt/torch/sparsity/attention_sparsity/methods/{flash_softmax_skip.py => flash_skip_softmax.py} (90%) delete mode 100644 modelopt/torch/sparsity/attention_sparsity/nn/__init__.py rename modelopt/torch/sparsity/attention_sparsity/{nn => }/sparse_attention.py (83%) create mode 100644 tests/_test_utils/torch_sparsity/sparse_attention_common.py create mode 100644 tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py rename tests/examples/llm_sparsity/{ => weight_sparsity}/test_llama_sparsify.py (93%) create mode 100644 tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py create mode 
100644 tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py create mode 100644 tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py diff --git a/examples/llm_sparse_attention/hf_spar_attn.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py similarity index 83% rename from examples/llm_sparse_attention/hf_spar_attn.py rename to examples/llm_sparsity/attention_sparsity/hf_sa.py index 461af581e..2f68cfa68 100644 --- a/examples/llm_sparse_attention/hf_spar_attn.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -22,64 +22,43 @@ import numpy as np import torch -import torch.nn as nn from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer +import modelopt.torch.opt as mto import modelopt.torch.sparsity.attention_sparsity as mtsa from modelopt.torch.export import export_hf_checkpoint from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig -from modelopt.torch.sparsity.attention_sparsity.config import ( - SKIP_SOFTMAX_CALIB, - SKIP_SOFTMAX_DEFAULT, -) -from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule +from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule from modelopt.torch.utils.memory_monitor import launch_memory_monitor RAND_SEED = 1234 +# Enable HuggingFace checkpointing support +mto.enable_huggingface_checkpointing() + # You can define custom configurations or use the default SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, - "skip_softmax_calib": SKIP_SOFTMAX_CALIB, } -def print_sparsity_stats(model: nn.Module): - """Print sparsity statistics if available.""" - module_stats = [] - for name, module in model.named_modules(): - if hasattr(module, "get_stats"): - stats = module.get_stats() - if stats and "average_sparsity" in stats: - module_stats.append((name, stats["average_sparsity"])) - - if not module_stats: - print("No sparsity statistics available") - return - - # Check if all modules have the same sparsity - sparsities = [s for _, s in module_stats] - if len(set(sparsities)) == 1: - # All identical - show summary - print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}") - else: - # Different sparsities - show individual values - avg_sparsity = sum(sparsities) / len(sparsities) - print(f"Average sparsity: {avg_sparsity:.2%}") - print("Per-module breakdown:") - for name, sparsity in module_stats: - print(f" {name}: {sparsity:.2%} sparse") - - def get_narrativeqa_samples(num_samples=3): """Load samples from NarrativeQA dataset for testing. 
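With checkpointing enabled, the sparse-attention state is intended to survive the usual HuggingFace save/load round trip; a hedged sketch (the path is illustrative):

.. code-block:: python

    # With mto.enable_huggingface_checkpointing() active, the ModelOpt state should
    # be saved alongside the weights and restored on load.
    model.save_pretrained("./llama-sparse-attn")
    restored = AutoModelForCausalLM.from_pretrained(
        "./llama-sparse-attn",
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )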
Args: num_samples: Number of samples to generate + + Raises: + RuntimeError: If dataset loading fails + ValueError: If no valid samples could be loaded """ - # Load NarrativeQA dataset - dataset = load_dataset("narrativeqa", split="test", streaming=True) + # Load NarrativeQA dataset with retry logic + try: + dataset = load_dataset("narrativeqa", split="test", streaming=True) + except Exception as e: + raise RuntimeError(f"Failed to load NarrativeQA dataset: {e}") samples = [] for i, item in enumerate(dataset): @@ -120,8 +99,10 @@ def truncate_text(text: str, tokenizer, max_length: int): return text # Need to truncate - preserve beginning and end - # Reserve some tokens for special tokens - available_tokens = max_length - 2 # Account for special tokens + # Calculate actual special tokens used + dummy_tokens = tokenizer.encode("", add_special_tokens=True) + special_token_count = len(dummy_tokens) + available_tokens = max_length - special_token_count # Split tokens roughly in half for beginning and end begin_tokens = available_tokens // 2 @@ -173,9 +154,7 @@ def verify_outputs(model, tokenizer, args): print("BASELINE vs SPARSE ATTENTION COMPARISON") print("=" * 60) print(f"\nTest prompt: {display_prompt}") - print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})") - if "[...]" in truncated_prompt: - print("Note: Text was middle-truncated to fit token limit") + print(f"Input tokens: {inputs['input_ids'].shape[1]}") # Helper function to generate text def generate_text(model, inputs, args, tokenizer): @@ -235,23 +214,13 @@ def sparsify_model(model, args): modified_sparse_cfg[pattern] = modified_cfg # Create new config with modified settings - sparse_config = SparseAttentionConfig( - method=base_config["method"], - sparse_cfg=modified_sparse_cfg, - collect_stats=True, # Enable stats collection for monitoring - ) + sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg) - # Sparsify with optional calibration - framework handles calibration automatically + # Sparsify the model model = mtsa.sparsify(model, config=sparse_config) print("Sparse attention applied successfully!") - # Show sparsity statistics - print("\n" + "=" * 60) - print("Sparsity Statistics") - print("=" * 60) - print_sparsity_stats(model) - return model diff --git a/examples/llm_sparsity/.gitignore b/examples/llm_sparsity/weight_sparsity/.gitignore similarity index 100% rename from examples/llm_sparsity/.gitignore rename to examples/llm_sparsity/weight_sparsity/.gitignore diff --git a/examples/llm_sparsity/README.md b/examples/llm_sparsity/weight_sparsity/README.md similarity index 100% rename from examples/llm_sparsity/README.md rename to examples/llm_sparsity/weight_sparsity/README.md diff --git a/examples/llm_sparsity/data_prep.py b/examples/llm_sparsity/weight_sparsity/data_prep.py similarity index 100% rename from examples/llm_sparsity/data_prep.py rename to examples/llm_sparsity/weight_sparsity/data_prep.py diff --git a/examples/llm_sparsity/eval.py b/examples/llm_sparsity/weight_sparsity/eval.py similarity index 100% rename from examples/llm_sparsity/eval.py rename to examples/llm_sparsity/weight_sparsity/eval.py diff --git a/examples/llm_sparsity/export_trtllm_ckpt.py b/examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py similarity index 100% rename from examples/llm_sparsity/export_trtllm_ckpt.py rename to examples/llm_sparsity/weight_sparsity/export_trtllm_ckpt.py diff --git a/examples/llm_sparsity/finetune.py b/examples/llm_sparsity/weight_sparsity/finetune.py similarity index 95% 
rename from examples/llm_sparsity/finetune.py rename to examples/llm_sparsity/weight_sparsity/finetune.py index 3cfc1073f..869068dbd 100644 --- a/examples/llm_sparsity/finetune.py +++ b/examples/llm_sparsity/weight_sparsity/finetune.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li diff --git a/examples/llm_sparsity/hf_pts.py b/examples/llm_sparsity/weight_sparsity/hf_pts.py similarity index 100% rename from examples/llm_sparsity/hf_pts.py rename to examples/llm_sparsity/weight_sparsity/hf_pts.py diff --git a/examples/llm_sparsity/launch_finetune.sh b/examples/llm_sparsity/weight_sparsity/launch_finetune.sh similarity index 100% rename from examples/llm_sparsity/launch_finetune.sh rename to examples/llm_sparsity/weight_sparsity/launch_finetune.sh diff --git a/examples/llm_sparsity/requirements.txt b/examples/llm_sparsity/weight_sparsity/requirements.txt similarity index 100% rename from examples/llm_sparsity/requirements.txt rename to examples/llm_sparsity/weight_sparsity/requirements.txt diff --git a/examples/llm_sparsity/utils.py b/examples/llm_sparsity/weight_sparsity/utils.py similarity index 100% rename from examples/llm_sparsity/utils.py rename to examples/llm_sparsity/weight_sparsity/utils.py diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py b/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py deleted file mode 100644 index 3b616e8e3..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Calibration framework for sparse attention methods.""" - -from .calibrate import calibrate_sparse_attention -from .calibrator import DynamicThresholdCalibrator -from .dataset import RulerDatasetBuilder - -__all__ = [ - "DynamicThresholdCalibrator", - "RulerDatasetBuilder", - "calibrate_sparse_attention", -] diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 5fdab0032..e72dacc94 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -34,18 +34,18 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): """Sparse attention attribute configuration for pattern-based module config.""" + method: str = ModeloptField( + default="flash_skip_softmax", + title="Sparse attention method.", + description="The sparse attention method to use (e.g., 'flash_skip_softmax').", + ) + enable: bool = ModeloptField( default=True, title="Enable sparse attention.", description="If True, enables sparse attention. If False, bypasses sparsity.", ) - method: str = ModeloptField( - default="flash_softmax_skip", - title="Sparse attention method.", - description="The sparse attention method to use (e.g., 'flash_softmax_skip').", - ) - threshold: float | dict[str, float] = ModeloptField( default=1e-3, title="Sparsity threshold.", @@ -67,12 +67,6 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="Block column size for block-wise sparsity in Flash Attention.", ) - collect_stats: bool = ModeloptField( - default=False, - title="Collect statistics.", - description="Whether to collect sparsity statistics during forward pass.", - ) - backend: str = ModeloptField( default="pytorch", title="Backend implementation.", @@ -156,103 +150,12 @@ def validate_threshold(cls, v): return v -class CalibrationConfig(ModeloptBaseConfig): - """Configuration for automatic threshold calibration using RULER dataset. - - Calibration learns a dynamic threshold λ = scale_factor / sequence_length that - achieves target sparsity. Only supports prefill phase (seq_len > 1). - """ - - target_sparse_ratio: float = ModeloptField( - default=0.5, - title="Target sparsity ratio", - description="Target ratio of sparse attention blocks (0.0 to 1.0).", - ) - - samples: int = ModeloptField( - default=24, - title="Calibration samples", - description="Total number of RULER samples for calibration (distributed across length bins).", - ) - - max_seqlen: int = ModeloptField( - default=32768, - title="Maximum sequence length", - description="Maximum sequence length for calibration (length bins auto-generated as powers of 2).", - ) - - num_length_bins: int = ModeloptField( - default=4, - title="Number of length bins", - description="Number of length bins to generate (hidden parameter, default: 4).", - ) - - threshold_trials: list[float] | None = ModeloptField( - default=None, - title="Threshold trials", - description=( - "List of threshold values to test during calibration. 
" - "If None, uses default: [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]" - ), - ) - - @field_validator("threshold_trials") - @classmethod - def validate_threshold_trials(cls, v): - """Validate threshold_trials are in valid range.""" - if v is not None: - if not isinstance(v, list): - raise ValueError(f"threshold_trials must be a list, got {type(v)}") - if len(v) == 0: - raise ValueError("threshold_trials must not be empty") - for threshold in v: - if not isinstance(threshold, (int, float)): - raise ValueError(f"All threshold_trials must be numbers, got {type(threshold)}") - if threshold <= 0 or threshold >= 1: - raise ValueError( - f"All threshold_trials must be in range (0, 1), got {threshold}" - ) - return v - - @field_validator("target_sparse_ratio") - @classmethod - def validate_target_sparse_ratio(cls, v): - """Validate target sparsity ratio is between 0 and 1.""" - if not 0.0 <= v <= 1.0: - raise ValueError(f"target_sparse_ratio must be between 0.0 and 1.0, got {v}") - return v - - @field_validator("samples") - @classmethod - def validate_samples(cls, v): - """Validate samples is positive.""" - if v <= 0: - raise ValueError(f"samples must be positive, got {v}") - return v - - @field_validator("max_seqlen") - @classmethod - def validate_max_seqlen(cls, v): - """Validate max_seqlen is at least 1024.""" - if v < 1024: - raise ValueError(f"max_seqlen must be >= 1024, got {v}") - return v - - @field_validator("num_length_bins") - @classmethod - def validate_num_length_bins(cls, v): - """Validate num_length_bins is positive.""" - if v <= 0: - raise ValueError(f"num_length_bins must be positive, got {v}") - return v - - # Pre-defined Sparse Attention Configuration # Default configuration with block-wise sparsity optimized for Flash Attention SKIP_SOFTMAX_DEFAULT = { - "method": "flash_softmax_skip", "sparse_cfg": { "*attn*": { + "method": "flash_skip_softmax", "threshold": { "prefill": 1e-3, # More aggressive during prefill "decode": 1e-4, # Conservative during decode @@ -267,28 +170,6 @@ def validate_num_length_bins(cls, v): } -# Configuration with RULER calibration -# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length -# The calibrated threshold adapts to sequence length for optimal sparsity -SKIP_SOFTMAX_CALIB = { - "method": "flash_softmax_skip", - "sparse_cfg": { - "*attn*": { - "br": 128, - "bc": 128, - "backend": "pytorch", # Only pytorch backend supported - "enable": True, - "calibration": { - "target_sparse_ratio": 0.5, - "samples": 120, - "max_seqlen": 8192, - }, - }, - "default": {"enable": False}, - }, -} - - class SparseAttentionConfig(ModeloptBaseConfig): """Base configuration for sparse attention optimization. @@ -296,17 +177,12 @@ class SparseAttentionConfig(ModeloptBaseConfig): attention methods and supports pattern-based layer configuration. 
""" - # Method selection - method: str = Field("flash_softmax_skip", description="Sparse attention method to use") - - # Statistics collection - collect_stats: bool = Field( - False, description="Whether to collect sparsity statistics during forward pass" - ) - # Pattern-based sparse configuration (similar to quant_cfg in quantization) sparse_cfg: SparseAttentionCfgType = ModeloptField( - default={"*attention*": {"enable": True}, "default": {"enable": False}}, + default={ + "*attention*": {"method": "flash_skip_softmax", "enable": True}, + "default": {"enable": False}, + }, title="Sparse attention configuration", description="Pattern-based configuration for sparse attention. Keys are patterns to match module names, " "values are configuration dicts with parameters like 'threshold', 'enable', and 'calibration'.", @@ -319,19 +195,15 @@ class SparseAttentionConfig(ModeloptBaseConfig): ) -class FlashSoftmaxSkipConfig(SparseAttentionConfig): +class FlashSkipSoftmaxConfig(SparseAttentionConfig): """Configuration for Flash Attention-aware softmax skip sparse attention.""" - # Override method to default to flash_softmax_skip - method: str = Field( - "flash_softmax_skip", description="Sparse attention method (fixed to flash_softmax_skip)" - ) - - # Override sparse_cfg with flash_softmax_skip specific defaults + # Override sparse_cfg with flash_skip_softmax specific defaults sparse_cfg: SparseAttentionCfgType = ModeloptField( default={ "*attention*": { - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-4}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported @@ -340,17 +212,15 @@ class FlashSoftmaxSkipConfig(SparseAttentionConfig): "default": {"enable": False}, }, title="Flash softmax skip sparse configuration", - description="Pattern-based configuration with flash_softmax_skip specific defaults. " + description="Pattern-based configuration with flash_skip_softmax specific defaults. 
" "Includes FA block sizes (br, bc) and correction factor settings.", validate_default=True, ) __all__ = [ - "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", - "CalibrationConfig", - "FlashSoftmaxSkipConfig", + "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", "SparseAttentionCfgType", "SparseAttentionConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 028e2bb67..25347c37f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -26,8 +26,8 @@ from modelopt.torch.utils import get_unwrapped_name from .config import SparseAttentionConfig -from .nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry from .plugins.huggingface import register_sparse_attention_on_the_fly +from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry def is_attn_sparsified(model: nn.Module) -> bool: @@ -67,7 +67,7 @@ def convert_to_sparse_attention_model( # Apply configuration to sparse attention modules sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} - set_sparse_attention_by_cfg(model, sparse_cfg, config) + set_sparse_attention_by_cfg(model, sparse_cfg) # Create metadata metadata = {} @@ -106,33 +106,31 @@ def _replace_sparse_attention_modules(model: nn.Module, version=None): _replace_sparse_attention_modules(getattr(model, name), version=version) -def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict, config: SparseAttentionConfig): +def set_sparse_attention_by_cfg(model: nn.Module, sparse_cfg: dict): """Apply sparse attention configuration to model. Similar to quantization's set_quantizer_by_cfg. Args: model: Model with sparse attention modules - sparse_cfg: Sparse configuration dictionary - config: Global sparse attention configuration + sparse_cfg: Sparse configuration dictionary mapping patterns to attributes """ sparse_cfg = sparse_cfg.copy() # Apply default first if exists if "default" in sparse_cfg: - set_sparse_attention_attribute(model, "*", sparse_cfg["default"], config) + set_sparse_attention_attribute(model, "*", sparse_cfg["default"]) sparse_cfg.pop("default") # Apply pattern-specific configs for pattern, cfg in sparse_cfg.items(): - set_sparse_attention_attribute(model, pattern, cfg, config) + set_sparse_attention_attribute(model, pattern, cfg) def set_sparse_attention_attribute( model: nn.Module, wildcard_or_filter: str | Callable, attribute_cfg: dict[str, Any], - global_config: SparseAttentionConfig, ): """Set sparse attention attributes for modules matching pattern. 
@@ -141,19 +139,11 @@ def set_sparse_attention_attribute( Args: model: Model to configure wildcard_or_filter: Pattern to match module names - attribute_cfg: Attributes to apply - global_config: Global sparse attention configuration + attribute_cfg: Attributes to apply (must include 'method') """ - # Merge global config fields with pattern config # Filter out model-level configs that shouldn't be passed to modules module_cfg = {k: v for k, v in attribute_cfg.items() if k != "calibration"} - full_cfg = { - "method": global_config.method, - "collect_stats": global_config.collect_stats, - **module_cfg, - } - for name, module in model.named_modules(): if not isinstance(module, SparseAttentionModule): continue @@ -165,11 +155,11 @@ def set_sparse_attention_attribute( elif callable(wildcard_or_filter): matched = wildcard_or_filter(name) else: - continue + raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter)}") if matched: # Apply config using the same method as TensorQuantizer - module.set_from_attribute_config(full_cfg) + module.set_from_attribute_config(module_cfg) def restore_sparse_attention_model( @@ -236,16 +226,11 @@ def update_sparse_attention_metadata( if isinstance(module, SparseAttentionModule): module_name = get_unwrapped_name(name, model) - # Collect method config from module attributes - method_config = { - k[1:]: v - for k, v in module.__dict__.items() - if k.startswith("_") and k not in ("_method", "_enabled", "_sparse_method_instance") - } - + # Save the method configuration that was used + # _method_config already contains the validated config dict module_state = { "method": module._sparse_method_instance.name, - "method_config": method_config, + "method_config": module._method_config.copy(), } sparse_state[module_name] = module_state @@ -353,23 +338,16 @@ def print_sparse_attention_summary(model: nn.Module): method = getattr(module, "_method", "unknown") method_counts[method] = method_counts.get(method, 0) + 1 - print(f"\n{'=' * 70}") - print(f"{'Sparse Attention Summary':^70}") - print(f"{'=' * 70}") print(f"Total sparse attention modules: {len(sparse_modules)}") - print(f" Enabled: {enabled_count}") - print(f" Disabled: {disabled_count}") + print(f"Enabled: {enabled_count}") + print(f"Disabled: {disabled_count}") if method_counts: print("\nMethods:") for method, count in sorted(method_counts.items()): - print(f" {method}: {count}") - - print(f"\n{'Module Details':^70}") - print(f"{'-' * 70}") + print(f"{method}: {count}") for name, module in sparse_modules: - status = "✓" if module.is_enabled else "✗" method = getattr(module, "_method", "unknown") threshold = getattr(module, "_threshold", "N/A") @@ -381,7 +359,4 @@ def print_sparse_attention_summary(model: nn.Module): else: threshold_str = str(threshold) - print(f"{status} {name}") - print(f" Method: {method}, Threshold: {threshold_str}") - - print(f"{'=' * 70}\n") + print(f"{name}: Method: {method}, Threshold: {threshold_str}") diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py index 5120bd755..8a109fda7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py @@ -24,4 +24,4 @@ ] # Import method implementations to trigger registration -from . import flash_softmax_skip +from . 
import flash_skip_softmax diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py similarity index 90% rename from modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py rename to modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 04b696d11..8801bafb0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Flash Attention-aware softmax skip method for sparse attention. This module implements block-wise sparsity that aligns with Flash Attention's @@ -12,8 +27,8 @@ from . import SparseAttentionMethod, register_sparse_method -@register_sparse_method("flash_softmax_skip") -class FlashSoftmaxSkipMethod(SparseAttentionMethod): +@register_sparse_method("flash_skip_softmax") +class FlashSkipSoftmax(SparseAttentionMethod): """Flash Attention-aware softmax skip sparse attention method. Implements row-level block-wise sparsity aligned with Flash Attention's @@ -25,20 +40,20 @@ def __init__(self, method_config: dict | None = None): Args: method_config: Configuration dict with threshold, br, bc, is_causal, etc. + All required fields should have defaults from SparseAttentionAttributeConfig. 
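For example, the per-phase thresholds and Flash Attention block sizes are passed in as a plain dictionary (values mirror the defaults defined in the config module):

.. code-block:: python

    method = FlashSkipSoftmax(
        method_config={
            "threshold": {"prefill": 1e-3, "decode": 1e-4},  # resolved per phase at call time
            "br": 128,
            "bc": 128,
            "backend": "pytorch",
            "is_causal": True,
        }
    )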
""" config = method_config or {} - # Extract configuration - self.threshold_config = config.get("threshold", 1e-4) - self.br = config.get("br", 128) - self.bc = config.get("bc", 128) + # Extract configuration (defaults handled by Pydantic) + self.threshold_config = config["threshold"] + self.br = config["br"] + self.bc = config["bc"] + self.backend = config["backend"] + self.is_causal = config["is_causal"] + + # Optional parameters not in Pydantic config self.enable_correction_factor = config.get("enable_correction_factor", True) - self.collect_stats = config.get("collect_stats", True) self.phase = config.get("phase", None) - self.backend = config.get("backend", "pytorch") - self.is_causal = config.get("is_causal", True) - # Calibration mode: when True, prevent threshold updates to preserve calibrator's test threshold - self._calibration_mode = False # Initialize threshold if isinstance(self.threshold_config, dict): @@ -55,10 +70,6 @@ def _update_threshold(self, phase: str): phase, self.threshold_config.get("default", self.threshold) ) - def set_calibration_mode(self, enabled: bool): - """Set calibration mode to prevent _update_threshold from modifying the threshold.""" - self._calibration_mode = enabled - def _infer_phase(self, attention_scores: torch.Tensor) -> str: """Infer phase from attention scores shape.""" return "decode" if attention_scores.shape[2] == 1 else "prefill" @@ -267,9 +278,8 @@ def apply_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase (skip during calibration) - if not self._calibration_mode: - self._update_threshold(phase) + # Update threshold for the detected phase + self._update_threshold(phase) # Apply block-wise sparsity sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) @@ -286,4 +296,4 @@ def apply_sparsity( @property def name(self) -> str: """Method identifier.""" - return "flash_softmax_skip" + return "flash_skip_softmax" diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 081ad9e27..df7b5853b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -15,6 +15,8 @@ """Registry and base class for sparse attention methods.""" +import re +import warnings from abc import ABC, abstractmethod import torch @@ -53,6 +55,27 @@ def name(self) -> str: _SPARSE_ATTENTION_METHODS: dict[str, dict[str, type[SparseAttentionMethod]]] = {} +def _version_key(version_str: str) -> list[int]: + """Extract numeric parts for proper version sorting. + + Args: + version_str: Version string (e.g., "v1", "v2", "v10") + + Returns: + List of integers extracted from version string for sorting + + Examples: + >>> _version_key("v1") + [1] + >>> _version_key("v10") + [10] + >>> _version_key("v2.3.1") + [2, 3, 1] + """ + parts = re.findall(r"\d+", version_str) + return [int(p) for p in parts] if parts else [0] + + def register_sparse_method(name: str, version: str = "v1"): """Decorator to register sparse attention methods with version support. @@ -60,10 +83,10 @@ def register_sparse_method(name: str, version: str = "v1"): name: Method name to register version: Version string (default: "v1") - Example: + Example:: + @register_sparse_method("my_method", version="v3") - class MyMethodV3(SparseAttentionMethod): - ... + class MyMethodV3(SparseAttentionMethod): ... 
""" def decorator(cls: type[SparseAttentionMethod]): @@ -71,8 +94,6 @@ def decorator(cls: type[SparseAttentionMethod]): _SPARSE_ATTENTION_METHODS[name] = {} if version in _SPARSE_ATTENTION_METHODS[name]: - import warnings - warnings.warn( f"Overriding existing sparse attention method: {name}@{version}", RuntimeWarning, @@ -99,8 +120,8 @@ def get_sparse_method(name: str, version: str | None = None) -> type[SparseAtten ValueError: If method name or version is not registered Example: - >>> get_sparse_method("flash_softmax_skip") # Latest version - >>> get_sparse_method("flash_softmax_skip", "v1") # Specific version + >>> get_sparse_method("flash_skip_softmax") # Latest version + >>> get_sparse_method("flash_skip_softmax", "v1") # Specific version """ if name not in _SPARSE_ATTENTION_METHODS: available = list(_SPARSE_ATTENTION_METHODS.keys()) @@ -109,7 +130,7 @@ def get_sparse_method(name: str, version: str | None = None) -> type[SparseAtten method_versions = _SPARSE_ATTENTION_METHODS[name] if not version: - version = sorted(method_versions.keys())[-1] + version = sorted(method_versions.keys(), key=_version_key)[-1] if version not in method_versions: available_versions = list(method_versions.keys()) diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index 908f3ad89..88434e746 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -19,15 +19,13 @@ import torch -from modelopt.torch.opt.conversion import ModeloptStateManager, apply_mode +from modelopt.torch.opt.conversion import apply_mode from modelopt.torch.opt.searcher import ForwardLoop -from .calibration import calibrate_sparse_attention from .config import SparseAttentionConfig from .mode import SparseAttentionModeRegistry __all__ = [ - "calibrate", "sparsify", ] @@ -39,19 +37,16 @@ def sparsify( ) -> torch.nn.Module: """Applies sparse attention optimization to the model in-place. - This method performs replacement of attention modules with their sparse counterparts and - optionally performs calibration as specified by ``config``. - ``forward_loop`` is used to forward data through the model and gather statistics for calibration. + This method performs replacement of attention modules with their sparse counterparts. Args: model: A pytorch model config: A dictionary or an instance of :class:`SparseAttentionConfig ` - specifying the values for keys ``"sparse_cfg"``, ``"method"``, and optionally ``"calibration"``. + specifying the values for keys ``"sparse_cfg"`` and ``"method"``. The ``"sparse_cfg"`` key specifies the sparse attention configurations. - The ``"method"`` key specifies the sparse attention method (e.g., "softmax_skip"). - The ``"calibration"`` key specifies calibration settings if automatic threshold tuning is desired. + The ``"method"`` key specifies the sparse attention method (e.g., "flash_skip_softmax"). Sparse attention configurations is a dictionary mapping wildcards or filter functions to its sparse attention attributes. The wildcards or filter functions are matched @@ -63,22 +58,13 @@ def sparsify( .. 
code-block::python config = { - "method": "softmax_skip", + "method": "flash_skip_softmax", "sparse_cfg": { - # Phase-aware thresholds with backend selection and calibration "*attention*": { "threshold": {"prefill": 1e-3, "decode": 1e-5}, - "backend": "pytorch", # Only pytorch backend supported + "backend": "pytorch", "enable": True, - "calibration": { # Optional: enables automatic threshold calibration - "target_sparse_ratio": 0.5, - "samples": 48, - "max_seqlen": 8192, - }, }, - # Disable for specific layers - "*layer_0*": {"enable": False}, - # Default settings "default": {"enable": False}, }, } @@ -89,11 +75,7 @@ def sparsify( This requires the model to be loaded with ``attn_implementation="eager"``. - forward_loop: A callable that forwards all calibration data through the model. This is used - to gather statistics for calibration. It should take model as the argument. It does not need - to return anything. - - This argument is only required when calibration is enabled in the config. + forward_loop: Reserved for future use. Here are a few examples for correct ``forward_loop`` definitions: @@ -144,54 +126,4 @@ def forward_loop(model) -> float: model, mode=[("sparse_attention", config)], registry=SparseAttentionModeRegistry ) - # Calibrate the sparsity ratio of the attention modules - return calibrate(model, forward_loop=forward_loop) - - -def calibrate( - model: torch.nn.Module, - forward_loop: ForwardLoop | None = None, -) -> torch.nn.Module: - """Calibrates sparse attention thresholds based on target sparsity. - - This function performs calibration to find optimal thresholds that achieve - the target sparsity ratio specified in the sparse attention configuration. - - Args: - model: A pytorch model with sparse attention already applied - forward_loop: Optional callable that forwards calibration data through the model. - It should take model as the argument and can optionally return metrics. - If None, will auto-generate RULER dataset for calibration. - - Returns: - The calibrated model with optimized sparse attention thresholds. - If no calibration is configured, returns the model unchanged. - """ - # Get the sparse attention config from the model's state - if not ModeloptStateManager.is_converted(model): - return model - - manager = ModeloptStateManager(model) - - sparse_attn_config = next( - (state["config"] for name, state in manager._state if name == "sparse_attention"), None - ) - - if sparse_attn_config is None: - return model - - # Check if calibration is configured in any sparse_cfg pattern - # Note: sparse_attn_config is always a dict (stored via config.model_dump()) - sparse_cfg = sparse_attn_config.get("sparse_cfg", {}) - - has_calibration = any( - isinstance(cfg, dict) and "calibration" in cfg for cfg in sparse_cfg.values() - ) - - if not has_calibration: - return model - - # Run calibration (handles stats collection internally) - calibrate_sparse_attention(model, sparse_attn_config, forward_loop=forward_loop) - return model diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py b/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py deleted file mode 100644 index 00ff275bc..000000000 --- a/modelopt/torch/sparsity/attention_sparsity/nn/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Neural network modules for sparse attention.""" - -from .sparse_attention import SparseAttentionModule, SparseAttentionRegistry - -__all__ = ["SparseAttentionModule", "SparseAttentionRegistry"] diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 0012257b6..b0cd1dff6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -15,12 +15,16 @@ """Dynamic sparse attention registration for HuggingFace models.""" +import logging + import torch.nn as nn import transformers from modelopt.torch.opt.dynamic import DynamicModule -from ..nn.sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry + +logger = logging.getLogger(__name__) class _GenericSparseAttention(SparseAttentionModule): @@ -80,10 +84,8 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: type_name = module_type.__name__ # Common attention module patterns - is_attention = ( - "attention" in type_name.lower() - or type_name.endswith("Attention") - or type_name.endswith("SelfAttention") + is_attention = "attention" in type_name.lower() or type_name.endswith( + ("Attention", "SelfAttention") ) if is_attention and module_type not in SparseAttentionRegistry: @@ -92,10 +94,12 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - print(f"Registered {type_name} for sparse attention optimization") + logger.info(f"Registered {type_name} for sparse attention optimization") if registered_count > 0: - print(f"Dynamically registered {registered_count} attention module types for sparsity") + logger.info( + f"Dynamically registered {registered_count} attention module types for sparsity" + ) return registered_count > 0 diff --git a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py similarity index 83% rename from modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py rename to modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index a45931224..16b08bf19 100644 --- a/modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -21,9 +21,8 @@ from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls from modelopt.torch.quantization.utils import replace_function -from ..config import SparseAttentionAttributeConfig -from ..methods import get_sparse_method -from .stats_manager import SparseAttentionStatsManager +from .config import SparseAttentionAttributeConfig +from .methods import get_sparse_method class SparseAttentionModule(DynamicModule): @@ -51,7 +50,7 @@ class SparseAttentionModule(DynamicModule): _enabled: bool Whether sparse attention is enabled _method: str - The sparse attention method 
to use (e.g., "flash_softmax_skip") + The sparse attention method to use (e.g., "flash_skip_softmax") _method_config: dict Configuration dictionary for the sparse method (threshold, br, bc, etc.) _sparse_method_instance: SparseAttentionMethod @@ -67,12 +66,10 @@ def set_from_attribute_config( Args: attribute_cfg: Sparse attention attribute configuration. - If None, uses default SparseAttentionAttributeConfig. """ - # Use default config if not provided - attribute_cfg = ( - attribute_cfg if attribute_cfg is not None else SparseAttentionAttributeConfig() - ) + # Ensure config is validated through Pydantic + if not isinstance(attribute_cfg, SparseAttentionAttributeConfig): + attribute_cfg = SparseAttentionAttributeConfig(**(attribute_cfg or {})) # Store raw config for method initialization self._method_config = {} @@ -87,8 +84,8 @@ def set_from_attribute_config( "method": ("_method", lambda val: str(val)), } - # Process each attribute from config - for attribute, val in attribute_cfg.items(): + # Process each attribute from validated config + for attribute, val in attribute_cfg.model_dump().items(): # Validate attribute if using config class if hasattr(SparseAttentionAttributeConfig, "model_fields"): assert attribute in SparseAttentionAttributeConfig.model_fields, ( @@ -132,10 +129,9 @@ def get_stats(self) -> dict: Returns: Dictionary with sparsity statistics including 'average_sparsity' if available. - Returns empty dict if stats manager is not enabled. + Returns empty dict (statistics collection will be added in calibration PR). """ - if self._stats_manager is not None and self._stats_manager.enabled: - return self._stats_manager.get_summary() + # TODO: Statistics collection will be added in calibration PR return {} def _setup(self): @@ -144,14 +140,6 @@ def _setup(self): if not hasattr(self, "_method"): self.set_from_attribute_config(None) - # Create stats manager if stats collection is enabled - if self._method_config.get("collect_stats", False): - self._stats_manager = SparseAttentionStatsManager( - module_name="sparse_attention", enabled=True - ) - else: - self._stats_manager = None - def forward(self, *args, **kwargs): """Forward with selected sparse attention method. @@ -169,10 +157,6 @@ def forward(self, *args, **kwargs): with context: result = super().forward(*args, **kwargs) - # Collect stats if manager is available - if self._stats_manager is not None and hasattr(self._sparse_method_instance, "_last_stats"): - self._stats_manager.collect(self._sparse_method_instance._last_stats) - return result def _get_sparse_context(self): diff --git a/tests/_test_utils/torch_sparsity/sparse_attention_common.py b/tests/_test_utils/torch_sparsity/sparse_attention_common.py new file mode 100644 index 000000000..7724908b0 --- /dev/null +++ b/tests/_test_utils/torch_sparsity/sparse_attention_common.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Common utilities for sparse attention testing.""" + +import torch +import torch.nn as nn + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +# Test models for sparse attention +class SimpleAttentionModel(nn.Module): + """Simple attention model for testing.""" + + def __init__(self, hidden_size=256, num_heads=8): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.attention = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True + ) + self.fc = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x, need_weights=False) + return self.fc(attn_output) + + @classmethod + def get_input(cls, hidden_size=256, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, hidden_size) + + +class SimpleTransformerEncoderLayer(nn.Module): + """Simple TransformerEncoderLayer wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, dim_feedforward=256): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + + def forward(self, x): + return self.layer(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=20, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +class SimpleTransformerEncoder(nn.Module): + """Simple TransformerEncoder wrapper for testing.""" + + def __init__(self, d_model=128, nhead=4, num_layers=2): + super().__init__() + self.d_model = d_model + self.nhead = nhead + self.encoder = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True), + num_layers=num_layers, + ) + + def forward(self, x): + return self.encoder(x) + + @classmethod + def get_input(cls, d_model=128, seq_len=10, batch_size=2): + """Get input tensor for testing.""" + return torch.randn(batch_size, seq_len, d_model) + + +# Test configurations +FLASH_SKIP_SOFTMAX_DEFAULT_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + +FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG = { + "sparse_cfg": { + "*attention*": { + "method": "flash_skip_softmax", + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "enable": True, + } + }, +} + + +def get_test_configs(): + """Get test configurations for parameterized tests. + + Note: Calibration config excluded (requires GPU and real tokenizers). + """ + return [FLASH_SKIP_SOFTMAX_DEFAULT_CFG, FLASH_SKIP_SOFTMAX_PHASE_AWARE_CFG] + + +def sparsify_model_and_forward(model, config, calib_data): + """Apply sparse attention and run forward passes. 
+ + Args: + model: Model to sparsify + config: Sparse attention configuration + calib_data: List of calibration data tensors + + Returns: + Sparsified model + """ + + def forward_loop(model): + for batch in calib_data: + model(batch) + + # Apply sparse attention + model = sparse_attn.sparsify(model, config, forward_loop=forward_loop) + + # Verify sparse attention modules were inserted + assert any(isinstance(m, SparseAttentionModule) for m in model.modules()), ( + "No sparse attention modules found" + ) + + # Test forward passes + model.eval() + with torch.no_grad(): + for batch in calib_data: + output = model(batch) + assert not torch.isnan(output).any(), "NaN in output" + assert output is not None, "Output is None" + + return model + + +def save_restore_test(model_cls, device, sparse_config): + """Test save and restore of sparse attention state. + + Args: + model_cls: Model class to test + device: Device to run on ('cpu' or 'cuda') + sparse_config: Sparse attention configuration + """ + # Create and sparsify reference model + model_sparse = model_cls().to(device) + calib_data = [model_sparse.get_input().to(device) for _ in range(2)] + + sparsify_model_and_forward(model_sparse, sparse_config, calib_data) + + # Save state + state_dict = mto.modelopt_state(model_sparse) + + # Restore to new model + model_restored = model_cls().to(device) + mto.restore_from_modelopt_state(model_restored, state_dict) + model_restored.load_state_dict(model_sparse.state_dict()) + + # Verify outputs match + test_input = calib_data[0] + model_sparse.eval() + model_restored.eval() + + with torch.no_grad(): + output_sparse = model_sparse(test_input) + output_restored = model_restored(test_input) + + assert torch.allclose(output_sparse, output_restored, atol=1e-6), ( + "Restored model output doesn't match original" + ) diff --git a/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py new file mode 100644 index 000000000..b82303990 --- /dev/null +++ b/tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test attention sparsity example script.""" + +import pytest +from _test_utils.examples.run_command import extend_cmd_parts, run_example_command +from _test_utils.torch.misc import minimum_gpu + + +def run_attention_sparsity_command(*, model: str, method: str = "skip_softmax", **kwargs): + """Run attention sparsity example script. 
+ + Args: + model: Path to model + method: Sparse attention method (corresponds to --sparse_attn arg) + **kwargs: Additional arguments to pass to the script + """ + kwargs.update( + { + "pyt_ckpt_path": model, + "sparse_attn": method, + } + ) + kwargs.setdefault("seq_len", 128) + kwargs.setdefault("num_samples", 1) + kwargs.setdefault("max_new_tokens", 16) + + cmd_parts = extend_cmd_parts(["python", "hf_sa.py"], **kwargs) + run_example_command(cmd_parts, "llm_sparsity/attention_sparsity") + + +@minimum_gpu(1) +@pytest.mark.parametrize("method", ["skip_softmax"]) +def test_attention_sparsity(tiny_llama_path, tmp_path, method): + """Test sparse attention with TinyLlama.""" + run_attention_sparsity_command( + model=tiny_llama_path, + method=method, + ) diff --git a/tests/examples/llm_sparsity/test_llama_sparsify.py b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py similarity index 93% rename from tests/examples/llm_sparsity/test_llama_sparsify.py rename to tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py index 7f9ef929b..7094b2989 100644 --- a/tests/examples/llm_sparsity/test_llama_sparsify.py +++ b/tests/examples/llm_sparsity/weight_sparsity/test_llama_sparsify.py @@ -31,7 +31,7 @@ def run_llm_sparsity_command( kwargs.setdefault("model_max_length", 1024) cmd_parts = extend_cmd_parts(["python", "hf_pts.py"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") def run_llm_sparsity_ft_command( @@ -51,13 +51,15 @@ def run_llm_sparsity_ft_command( kwargs.setdefault("eval_bs", 1) cmd_parts = extend_cmd_parts(["bash", "launch_finetune.sh"], **kwargs) - run_example_command(cmd_parts, "llm_sparsity") + run_example_command(cmd_parts, "llm_sparsity/weight_sparsity") @pytest.fixture(scope="session") def data_path(tmp_path_factory): data_path = tmp_path_factory.mktemp("data") - run_example_command(["python", "data_prep.py", "--save_path", data_path], "llm_sparsity") + run_example_command( + ["python", "data_prep.py", "--save_path", data_path], "llm_sparsity/weight_sparsity" + ) # Copy eval data to train path for faster test run_example_command( diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py new file mode 100644 index 000000000..bad077fdb --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
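Condensed to its essentials, the flow the shared test helpers above exercise looks like the following sketch: a toy attention wrapper (illustrative only, not part of this patch), a phase-aware config written inline, and the public ``sparsify`` entry point followed by a plain forward pass::

    import torch
    import torch.nn as nn

    import modelopt.torch.sparsity.attention_sparsity as sparse_attn


    class TinyAttn(nn.Module):
        """Illustrative wrapper; any module whose name matches the pattern works."""

        def __init__(self):
            super().__init__()
            self.attention = nn.MultiheadAttention(64, 4, batch_first=True)

        def forward(self, x):
            return self.attention(x, x, x, need_weights=False)[0]


    cfg = {
        "sparse_cfg": {
            "*attention*": {
                "method": "flash_skip_softmax",
                # Phase-aware thresholds: looser for prefill, tighter for decode.
                "threshold": {"prefill": 1e-3, "decode": 1e-5},
                "enable": True,
            }
        },
    }

    model = sparse_attn.sparsify(TinyAttn(), cfg)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 32, 64))
    assert not torch.isnan(out).any()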
+ +"""GPU tests for attention sparsity module.""" + +import pytest +import torch +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoder, + SimpleTransformerEncoderLayer, + get_test_configs, + save_restore_test, + sparsify_model_and_forward, +) + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +class TestAttentionSparsityGPU: + """GPU tests for attention sparsity.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup for each test.""" + self.device = torch.device("cuda") + torch.cuda.empty_cache() + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + @pytest.mark.parametrize("config", get_test_configs()) + def test_gpu_forward(self, model_cls, config): + """Test sparse attention forward pass on GPU.""" + model = model_cls().to(self.device) + calib_data = [model.get_input().to(self.device) for _ in range(2)] + + sparsify_model_and_forward(model, config, calib_data) + + # Additional GPU-specific checks + for batch in calib_data: + with torch.no_grad(): + output = model(batch) + assert output.device.type == "cuda" + + @pytest.mark.parametrize( + "model_cls", + [SimpleAttentionModel, SimpleTransformerEncoderLayer, SimpleTransformerEncoder], + ) + def test_save_restore(self, model_cls): + """Test save and restore on GPU.""" + save_restore_test(model_cls, "cuda", FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_different_dtypes(self, dtype): + """Test sparse attention with different dtypes.""" + model = SimpleTransformerEncoderLayer(d_model=256, nhead=8).to(self.device).to(dtype) + calib_data = [model.get_input(d_model=256).to(self.device).to(dtype) for _ in range(2)] + + sparse_model = sparsify_model_and_forward(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG, calib_data) + + # Test forward + x = model.get_input(d_model=256).to(self.device).to(dtype) + with torch.no_grad(): + output = sparse_model(x) + + assert output.dtype == dtype + assert not torch.isnan(output).any() + if dtype != torch.bfloat16: # bfloat16 can have inf + assert not torch.isinf(output).any() + + def test_backward_pass(self): + """Test that gradients flow correctly through sparse attention.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Enable training mode + model.train() + + x = model.get_input(hidden_size=128, seq_len=32).to(self.device) + x.requires_grad = True + + # Forward + output = model(x) + loss = output.sum() + + # Backward + loss.backward() + + # Check gradients exist + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + # Check model gradients + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + @pytest.mark.parametrize("seq_len", [1, 1024, 2048]) + def test_various_sequence_lengths(self, seq_len): + """Test sparse attention with various sequence lengths.""" + model = SimpleAttentionModel(hidden_size=128, num_heads=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(hidden_size=128, seq_len=seq_len, batch_size=1).to(self.device) + + 
model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (1, seq_len, 128) + assert not torch.isnan(output).any() + + @pytest.mark.parametrize("batch_size", [1, 8, 16]) + def test_various_batch_sizes(self, batch_size): + """Test sparse attention with various batch sizes.""" + model = SimpleTransformerEncoderLayer(d_model=128, nhead=4).to(self.device) + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + x = model.get_input(d_model=128, seq_len=64, batch_size=batch_size).to(self.device) + + model.eval() + with torch.no_grad(): + output = model(x) + + assert output.shape == (batch_size, 64, 128) + assert not torch.isnan(output).any() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py new file mode 100644 index 000000000..586cb3b9d --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration testing with locally created minimal Llama model.""" + +import pytest +import torch +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# Skip all tests if GPU is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available") + + +@pytest.fixture(scope="module") +def tiny_llama_dir(tmp_path_factory): + """Create minimal Llama model locally.""" + return create_tiny_llama_dir( + tmp_path_factory.mktemp("tiny_llama"), + with_tokenizer=True, + num_hidden_layers=2, # Minimal layers for fast testing + hidden_size=512, + intermediate_size=1024, + ) + + +@pytest.fixture(scope="module") +def tinyllama_model(tiny_llama_dir): + """Load locally created tiny Llama model.""" + model = AutoModelForCausalLM.from_pretrained( + tiny_llama_dir, + attn_implementation="eager", + device_map="cuda", + ) + return model + + +@pytest.fixture(scope="module") +def tinyllama_tokenizer(tiny_llama_dir): + """Load tokenizer for tiny Llama model.""" + tokenizer = AutoTokenizer.from_pretrained(tiny_llama_dir) + return tokenizer + + +class TestTinyLlama: + """TinyLlama sparse attention tests.""" + + def test_load_and_sparsify(self, tinyllama_model): + """Load TinyLlama and apply sparse attention.""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify 
sparse attention modules were added + sparse_count = sum( + 1 for m in sparse_model.modules() if isinstance(m, SparseAttentionModule) + ) + assert sparse_count > 0, "No sparse attention modules found" + + # Our tiny llama has 2 layers, so should have 2 attention modules + assert sparse_count == 2, f"Expected 2 sparse modules, got {sparse_count}" + + def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): + """Forward pass with seq_len=64 (prefill).""" + model = tinyllama_model + tokenizer = tinyllama_tokenizer + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create prefill input (seq_len > 1) + test_text = "Once upon a time in a land far away" + inputs = tokenizer(test_text, return_tensors="pt").to("cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(**inputs) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape[1] == inputs.input_ids.shape[1] # seq_len preserved + + def test_forward_decode(self, tinyllama_model): + """Forward pass with seq_len=1 (decode).""" + model = tinyllama_model + + config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-5, # More conservative for decode + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, config) + + # Create decode input (seq_len = 1) + input_ids = torch.randint(0, 32000, (1, 1), device="cuda") + + # Forward pass + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + # Verify output + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + assert outputs.logits.shape == (1, 1, 32000) # batch=1, seq=1, vocab_size + + def test_gqa_attention(self, tinyllama_model): + """Verify GQA support (num_kv_heads < num_heads).""" + model = tinyllama_model + + # Check if model uses GQA + config = model.config + has_gqa = hasattr(config, "num_key_value_heads") and ( + config.num_key_value_heads < config.num_attention_heads + ) + + if not has_gqa: + pytest.skip("Model does not use GQA") + + # Apply sparse attention + sparse_config = SparseAttentionConfig( + sparse_cfg={ + "*attn*": { + "threshold": 1e-3, + "backend": "pytorch", + "enable": True, + } + }, + ) + + sparse_model = sparse_attn.sparsify(model, sparse_config) + + # Test forward pass with GQA + input_ids = torch.randint(0, 32000, (1, 32), device="cuda") + + sparse_model.eval() + with torch.no_grad(): + outputs = sparse_model(input_ids) + + assert outputs.logits is not None + assert not torch.isnan(outputs.logits).any() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py new file mode 100644 index 000000000..b487d8639 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FlashSkipSoftmax method internals.""" + +import pytest +import torch + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.methods.flash_skip_softmax import FlashSkipSoftmax + + +class TestFlashSkipSoftmaxMethod: + """Test FlashSkipSoftmax method internals.""" + + def test_phase_inference(self): + """Test phase detection from attention score shape.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Prefill: seq_q > 1 + prefill_scores = torch.randn(2, 4, 64, 64) + assert method._infer_phase(prefill_scores) == "prefill" + + # Decode: seq_q = 1 + decode_scores = torch.randn(2, 4, 1, 64) + assert method._infer_phase(decode_scores) == "decode" + + def test_threshold_update_dict_config(self): + """Test threshold updates with dict config.""" + method = FlashSkipSoftmax( + { + "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Initially uses prefill threshold + initial_threshold = method.threshold + + # Update to decode + method._update_threshold("decode") + assert method.threshold == 1e-5 + assert method.threshold != initial_threshold + + # Update back to prefill + method._update_threshold("prefill") + assert method.threshold == 1e-3 + + def test_threshold_update_static_config(self): + """Test threshold with static float config.""" + method = FlashSkipSoftmax( + { + "threshold": 5e-4, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + initial_threshold = method.threshold + assert initial_threshold == 5e-4 + + # Should not change for static config + method._update_threshold("decode") + assert method.threshold == 5e-4 + + def test_block_reshaping_divisible(self): + """Test block reshaping with divisible sequence lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths divisible by 128 + attn = torch.randn(2, 4, 256, 256) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify block dimensions + assert blocked.shape == (2, 4, 2, 128, 2, 128) # 256/128 = 2 blocks + assert num_br == 2 + assert num_bc == 2 + assert padded_q == 256 # No padding + assert padded_k == 256 # No padding + + def test_block_reshaping_with_padding(self): + """Test block reshaping with non-divisible lengths.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Seq lengths NOT divisible by 128 + attn = torch.randn(2, 4, 200, 300) + blocked, num_br, num_bc, padded_q, padded_k = method._reshape_to_blocks(attn, 128, 128) + + # Verify padding applied + assert padded_q == 256 # ceil(200/128) * 128 = 2 * 128 + assert padded_k == 384 # ceil(300/128) * 128 = 3 * 128 + assert num_br == 2 + assert num_bc == 3 + assert blocked.shape == (2, 4, 2, 128, 3, 128) + + def test_correction_factor_calculation_prefill(self): + """Test correction 
factor for prefill phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Create simple attention pattern + attn = torch.randn(1, 1, 128, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify stats structure + assert "correction_factor" in stats + assert "sparsity" in stats + assert "phase" in stats + assert "total_blocks" in stats + assert stats["phase"] == "prefill" + assert 0 <= stats["correction_factor"] <= 1 + # Sparsity can be negative if threshold is too low (more blocks kept than expected) + assert -1 <= stats["sparsity"] <= 1 + + def test_correction_factor_calculation_decode(self): + """Test correction factor for decode phase.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-5, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Decode: single query + attn = torch.randn(1, 1, 1, 256) + + mask, stats = method.calc_correction_factor_and_p(attn, "decode") + + # Verify stats structure + assert stats["phase"] == "decode" + assert "correction_factor" in stats + assert 0 <= stats["sparsity"] <= 1 + assert mask.shape == (1, 1, 1, 256) + + def test_sparsity_statistics(self): + """Test sparsity statistics structure.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(1, 1, 128, 256) + _, stats = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify statistics are present + assert stats["total_blocks"] > 0 + assert "sparse_blocks" in stats + assert "sample_length" in stats + assert stats["sample_length"] == 256 + + def test_block_mask_correctness(self): + """Test block mask shape and type.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + attn = torch.randn(2, 4, 128, 256) + mask, _ = method.calc_correction_factor_and_p(attn, "prefill") + + # Verify mask properties + assert mask.shape == attn.shape + assert mask.dtype == torch.bool + assert mask.device == attn.device + + def test_causal_vs_noncausal(self): + """Test total_blocks calculation for causal vs non-causal.""" + config_base = { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + } + + method_causal = FlashSkipSoftmax({**config_base, "is_causal": True}) + method_noncausal = FlashSkipSoftmax({**config_base, "is_causal": False}) + + attn = torch.randn(1, 1, 256, 256) # 2x2 blocks + + _, stats_causal = method_causal.calc_correction_factor_and_p(attn, "prefill") + _, stats_noncausal = method_noncausal.calc_correction_factor_and_p(attn, "prefill") + + # Causal: 2*(2+1)/2 = 3 blocks + # Non-causal: 2*2 = 4 blocks + assert stats_causal["total_blocks"] == 3 + assert stats_noncausal["total_blocks"] == 4 + + def test_apply_sparsity_assertions(self): + """Test apply_sparsity input validation.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + + # Test: attention_scores required + with pytest.raises(AssertionError, match="attention_scores must be provided"): + method.apply_sparsity() + + # Test: 4D shape required + with pytest.raises(AssertionError, match="Expected 4D"): + method.apply_sparsity(attention_scores=torch.randn(2, 64, 64)) # 3D + + def test_name_property(self): + """Test method name property.""" + method = FlashSkipSoftmax( + { + "threshold": 1e-3, + "br": 
128, + "bc": 128, + "backend": "pytorch", + "is_causal": True, + } + ) + assert method.name == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py new file mode 100644 index 000000000..1824825f9 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test sparse attention configuration validation.""" + +import pytest +from pydantic import ValidationError + +pytest.importorskip("transformers") + +from modelopt.torch.sparsity.attention_sparsity.config import ( + SKIP_SOFTMAX_DEFAULT, + FlashSkipSoftmaxConfig, + SparseAttentionAttributeConfig, + SparseAttentionConfig, +) + + +class TestSparseAttentionAttributeConfig: + """Test SparseAttentionAttributeConfig validators.""" + + def test_valid_config(self): + """Test creating valid config.""" + config = SparseAttentionAttributeConfig( + method="flash_skip_softmax", + threshold=1e-4, + br=128, + bc=128, + enable=True, + ) + assert config.method == "flash_skip_softmax" + assert config.threshold == 1e-4 + assert config.br == 128 + assert config.bc == 128 + + def test_method_validation(self): + """Test method must be string.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + SparseAttentionAttributeConfig(method=123) + + def test_block_size_validation_negative(self): + """Test block sizes must be positive.""" + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(br=-1) + + with pytest.raises(ValidationError, match="Block size must be positive"): + SparseAttentionAttributeConfig(bc=0) + + def test_block_size_validation_large(self): + """Test that large block sizes are accepted.""" + # Large block sizes are allowed (warning removed for simplicity) + config = SparseAttentionAttributeConfig(br=2048) + assert config.br == 2048 + + def test_threshold_validation_range(self): + """Test threshold must be in range (0, 1).""" + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=-0.1) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.0) + + with pytest.raises(ValidationError, match="Threshold must be in range"): + SparseAttentionAttributeConfig(threshold=1.5) + + def test_threshold_validation_dict(self): + """Test threshold dict validation.""" + # Valid phase-aware threshold + config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) + assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + + # Invalid phase 
key + with pytest.raises(ValidationError, match="Invalid threshold phases"): + SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + + # Invalid threshold value in dict (negative) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + + # Invalid threshold value in dict (>= 1.0) + with pytest.raises(ValidationError, match="must be in range"): + SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + + def test_threshold_validation_type(self): + """Test threshold type validation.""" + with pytest.raises(ValidationError, match="Input should be a valid"): + SparseAttentionAttributeConfig(threshold="invalid") + + +class TestSparseAttentionConfig: + """Test SparseAttentionConfig.""" + + def test_default_config(self): + """Test default configuration.""" + config = SparseAttentionConfig() + assert "sparse_cfg" in config.model_dump() + # Check default pattern has method + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" + + def test_predefined_config(self): + """Test pre-defined configuration.""" + assert "sparse_cfg" in SKIP_SOFTMAX_DEFAULT + assert "method" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"]["*attn*"] + assert "*attn*" in SKIP_SOFTMAX_DEFAULT["sparse_cfg"] + + +class TestFlashSkipSoftmaxConfig: + """Test FlashSkipSoftmaxConfig.""" + + def test_default_values(self): + """Test default values for flash_skip_softmax config.""" + config = FlashSkipSoftmaxConfig() + assert "*attention*" in config.sparse_cfg + assert config.sparse_cfg["*attention*"]["method"] == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py new file mode 100644 index 000000000..8df8fe476 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
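In isolation, the configuration validators exercised above behave as in this small sketch, which relies only on the constraints the tests assert (phase keys limited to "prefill"/"decode", scalar thresholds restricted to the open interval (0, 1))::

    from pydantic import ValidationError

    from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig

    # Phase-aware thresholds are accepted as a dict keyed by phase.
    ok = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}, br=128, bc=128)
    assert ok.threshold == {"prefill": 1e-3, "decode": 1e-5}

    # Scalar thresholds outside (0, 1) are rejected by the validator.
    try:
        SparseAttentionAttributeConfig(threshold=1.5)
    except ValidationError as err:
        print(err)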
+ +"""Tests for sparse attention conversion and replacement.""" + +import pytest + +pytest.importorskip("transformers") + +import torch.nn as nn +from _test_utils.torch_sparsity.sparse_attention_common import ( + FLASH_SKIP_SOFTMAX_DEFAULT_CFG, + SimpleAttentionModel, + SimpleTransformerEncoderLayer, +) + +import modelopt.torch.opt as mto +import modelopt.torch.sparsity.attention_sparsity as sparse_attn +from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, + print_sparse_attention_summary, +) +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + + +class TestSparseAttentionReplacement: + """Test module replacement logic.""" + + def test_basic_replacement(self): + """Test that attention modules are replaced with sparse versions.""" + model = SimpleAttentionModel() + + # Count original attention modules + original_attention_count = sum( + isinstance(m, nn.MultiheadAttention) for m in model.modules() + ) + assert original_attention_count > 0 + + # Apply sparse attention + sparse_model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Count sparse attention modules + sparse_attention_count = sum( + isinstance(m, SparseAttentionModule) for m in sparse_model.modules() + ) + + # Verify replacement occurred + assert sparse_attention_count > 0 + + def test_enable_disable_toggle(self): + """Test enabling and disabling sparse attention.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Check initially enabled + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + # Disable all sparse attention modules + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Re-enable all sparse attention modules + enable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_pattern_based_replacement(self): + """Test pattern-based selective replacement.""" + model = SimpleTransformerEncoderLayer() + + # Apply with pattern + config = { + "sparse_cfg": { + "*self_attn*": { + "method": "flash_skip_softmax", + "threshold": 1e-4, + "br": 128, + "bc": 128, + "enable": True, + }, + "default": {"enable": False}, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + + # Verify sparse modules exist + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + +class TestConversionEdgeCases: + """Test edge cases and error paths in conversion.""" + + def test_callable_filter(self): + """Test using callable filter instead of wildcard.""" + model = SimpleAttentionModel() + + # Use callable filter + def filter_func(name): + return "attn" in name + + config = { + "sparse_cfg": { + filter_func: { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + }, + }, + } + + sparse_model = sparse_attn.sparsify(model, config) + has_sparse = any(isinstance(m, SparseAttentionModule) for m in sparse_model.modules()) + assert has_sparse + + def test_no_matching_modules(self): + """Test pattern that matches nothing.""" + model = SimpleAttentionModel() + + config = { + "sparse_cfg": { + "*nonexistent*": { + "method": "flash_skip_softmax", + "threshold": 1e-3, + "enable": True, + }, + }, + } + + # Should not error, even with no 
matches + sparse_attn.sparsify(model, config) + + def test_disable_enable_functions(self): + """Test disable/enable utility functions.""" + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + disable_sparse_attention, + enable_sparse_attention, + ) + + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Disable all + disable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert not module.is_enabled + + # Enable all + enable_sparse_attention(model, "*") + for module in model.modules(): + if isinstance(module, SparseAttentionModule): + assert module.is_enabled + + def test_print_sparse_attention_summary(self, capsys): + """Test print_sparse_attention_summary function.""" + model = SimpleAttentionModel() + model = sparse_attn.sparsify(model, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Print summary + print_sparse_attention_summary(model) + + # Capture output + captured = capsys.readouterr() + assert "Total sparse attention modules:" in captured.out + assert "Enabled:" in captured.out + + def test_restore_sparse_attention_model(self): + """Test save/restore via modelopt_state.""" + # Create and sparsify original model + model_orig = SimpleAttentionModel() + model_orig = sparse_attn.sparsify(model_orig, FLASH_SKIP_SOFTMAX_DEFAULT_CFG) + + # Save state + state_dict = mto.modelopt_state(model_orig) + + # Restore to new model + model_restored = SimpleAttentionModel() + mto.restore_from_modelopt_state(model_restored, state_dict) + + # Verify restoration + has_sparse = any(isinstance(m, SparseAttentionModule) for m in model_restored.modules()) + assert has_sparse + + # Verify module is configured + for module in model_restored.modules(): + if isinstance(module, SparseAttentionModule): + assert hasattr(module, "_method") + assert module._method == "flash_skip_softmax" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py new file mode 100644 index 000000000..e7e32e153 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
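For quick reference, the conversion utilities covered above compose as follows. This sketch reuses the shared fixtures added by this patch and assumes ``tests/_test_utils`` is importable; the comments state only what the tests assert::

    import modelopt.torch.sparsity.attention_sparsity as sparse_attn
    from modelopt.torch.sparsity.attention_sparsity.conversion import (
        disable_sparse_attention,
        enable_sparse_attention,
        print_sparse_attention_summary,
    )
    from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule

    from _test_utils.torch_sparsity.sparse_attention_common import (
        FLASH_SKIP_SOFTMAX_DEFAULT_CFG,
        SimpleAttentionModel,
    )

    model = sparse_attn.sparsify(SimpleAttentionModel(), FLASH_SKIP_SOFTMAX_DEFAULT_CFG)

    # Wildcard-based toggling: every matched SparseAttentionModule flips is_enabled.
    disable_sparse_attention(model, "*")
    enable_sparse_attention(model, "*")
    assert all(m.is_enabled for m in model.modules() if isinstance(m, SparseAttentionModule))

    # Prints the total and enabled sparse attention module counts.
    print_sparse_attention_summary(model)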
+ +"""Tests for sparse attention mode registry.""" + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.opt.mode import _ModeRegistryCls +from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry + + +def test_sparse_attention_mode_exists(): + """Test that sparse_attention mode is registered.""" + assert "sparse_attention" in SparseAttentionModeRegistry + + +def test_sparse_attention_mode_descriptor(): + """Test sparse attention mode descriptor properties.""" + mode_descriptor = _ModeRegistryCls.get_from_any("sparse_attention") + + assert mode_descriptor is not None + assert hasattr(mode_descriptor, "config_class") + assert hasattr(mode_descriptor, "convert") + + +def test_mode_registry_get(): + """Test getting mode from registry.""" + mode = SparseAttentionModeRegistry["sparse_attention"] + assert mode is not None
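The registry lookups these tests cover reduce to the sketch below; only the attribute names asserted above are relied on, and the note about ``apply_mode`` reflects how ``model_sparsify.sparsify`` invokes the mode earlier in this patch::

    from modelopt.torch.opt.mode import _ModeRegistryCls
    from modelopt.torch.sparsity.attention_sparsity.mode import SparseAttentionModeRegistry

    # The mode is registered under "sparse_attention" and can be looked up by name,
    # either from its own registry or from any registry.
    assert "sparse_attention" in SparseAttentionModeRegistry
    mode = SparseAttentionModeRegistry["sparse_attention"]
    descriptor = _ModeRegistryCls.get_from_any("sparse_attention")

    # The descriptor exposes the config class and the convert entry point that
    # apply_mode(..., registry=SparseAttentionModeRegistry) is expected to dispatch to.
    print(descriptor.config_class)
    print(descriptor.convert)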