diff --git a/.gitignore b/.gitignore index 931c8ca..6a904c7 100644 --- a/.gitignore +++ b/.gitignore @@ -236,3 +236,6 @@ compile_commands.json # Rust lib Cargo.lock + +/examples/results +*.npy \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..e390344 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,89 @@ +# TopK Kernel Benchmarking Suite + +Standalone benchmarking for Vortex's topk kernel variants, measuring kernel-level latency isolated from the full SGLang inference pipeline. + +## Kernel Variants + +| Kernel | Description | +|--------|-------------| +| `naive` | CUB radix sort (bf16 only) | +| `sglang_m0` | Two-stage hierarchical radix sort, no mapping | +| `sglang_m1` | + LUT mapping (requires `--lut-path`) | +| `sglang_m2` | + Quantile mapping (requires `--quantiles-path`) | +| `sglang_m3` | + Power mapping (configurable via `--mapping-power`) | +| `sglang_m4` | + Log mapping | + +## Quick Start + +```bash +# Activate environment +source /scr/dataset/yuke/xinrui/uv_env/vortex/bin/activate + +# Quick single-config test +python benchmarks/bench_topk.py \ --batch-sizes 8 \ --seq-lens 4096 \ --topk-vals 30 \ --num-kv-heads 2 \ --repeat 200 + +# Sweep with histogram analysis +python benchmarks/bench_topk.py \ --batch-sizes 4 8 16 \ --seq-lens 2048 4096 8192 \ --topk-vals 30 64 \ --num-kv-heads 2 \ --repeat 100 \ --histogram + +# Full sweep with JSON output +python benchmarks/bench_topk.py \ --output-json benchmarks/results.json \ --histogram +``` + +## CLI Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--batch-sizes` | 1 4 8 16 32 64 | Batch sizes to sweep | +| `--seq-lens` | 1024 2048 4096 8192 | Sequence lengths to sweep | +| `--topk-vals` | 16 30 64 | TopK values to sweep | +| `--num-kv-heads` | 2 4 8 | KV head counts to sweep | +| `--page-size` | 16 | Tokens per page | +| `--reserved-bos` | 1 | Reserved BOS 
pages | +| `--reserved-eos` | 2 | Reserved EOS pages | +| `--score-dtype` | bfloat16 | Score tensor dtype (bfloat16 or float32) | +| `--distributions` | normal lognormal uniform | Score distributions to test | +| `--warmup` | 10 | Warmup iterations | +| `--repeat` | 100 | Timed iterations | +| `--mapping-power` | 0.5 | Power parameter for mode=3 | +| `--lut-path` | None | Path to .npy uint8[256] LUT for mode=1 | +| `--quantiles-path` | None | Path to .npy float32[256] quantiles for mode=2 | +| `--output-json` | None | Save results to JSON file | +| `--filter-kernels` | None | Only run specific kernels (e.g., `naive sglang_m0`) | +| `--histogram` | False | Collect bin distribution statistics | + +## Histogram Analysis + +When `--histogram` is passed, each config additionally runs `topk_profile_histogram` and reports: + +- **max/mean ratio**: Peak bin count divided by average (lower = more uniform) +- **Gini coefficient**: Inequality measure of bin distribution (0 = perfectly uniform) +- **nonzero_bins**: How many of the 256 bins received any values + +This shows whether mapping modes improve bin uniformity for a given score distribution. 
+ +## Output Format + +``` +TopK Kernel Benchmark Results +GPU: NVIDIA H100 80GB HBM3 | SM count: 132 + +bs=8 | seq=4096 | topk=30 | heads=2 | pages/seg=256 | dist=normal + naive : 0.0420ms (median) +/- 0.0030ms [min=0.0390, max=0.0510] + sglang mode=0 : 0.0310ms (median) +/- 0.0020ms [min=0.0290, max=0.0380] + sglang mode=3 : 0.0330ms (median) +/- 0.0020ms [min=0.0300, max=0.0400] + sglang mode=4 : 0.0320ms (median) +/- 0.0020ms [min=0.0300, max=0.0390] + histogram stats : max/mean=3.99 gini=0.568 nonzero_bins=70/256 +``` diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py new file mode 100644 index 0000000..5531187 --- /dev/null +++ b/benchmarks/analyze_topk_distribution.py @@ -0,0 +1,494 @@ +""" +TopK distribution analysis and visualization. + +Loads profiling data from: + - profile_topk_distribution.py output (.npz): raw histograms, LUT tables + - bench_topk.py output (.json): benchmark results + per-mode histogram data + +Produces visualization plots for evaluating mapping mode effectiveness. 
+ +Usage: + python benchmarks/analyze_topk_distribution.py \ + --bench-json bench_hitrate.json \ + --output-dir plots/ + + python benchmarks/analyze_topk_distribution.py \ + --profile-npz profile_output.npz \ + --bench-json bench_hitrate.json \ + --output-dir plots/ --max-segments 8 +""" + +import argparse +import json +import os +from typing import Optional + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import numpy as np + +# Canonical mapping mode names — shared across all profiling/analysis tools +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 5: "Index Cache", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", + 13: "ExpStretch", + 14: "TopkWindow", +} + +MAPPING_MODE_FORMULAS = { + 0: "None (fp16 bucketing)", + 1: "LUT CDF (calibrated)", + 2: "Quantile (calibrated)", + 3: "Power: sign(x)*|x|^p", + 4: "Log: sign(x)*log(|x|+1)", + 5: "Index Cache", + 6: "Asinh: asinh(beta*x)", + 7: "Log1p: sign(x)*log1p(alpha*|x|)", + 8: "Trunc8: bf16 upper-8-bit bucketing", + 9: "Erf: erf(alpha*x)", + 10: "Tanh: tanh(alpha*x)", + 11: "Subtract: x - pivot (RadiK-style)", + 13: "ExpStretch: exp(alpha*x)", + 14: "TopkWindow: k-aware linear windowing", +} + + +def _mode_key_to_display(mode_key: str) -> str: + """Convert a mode key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale' to display name.""" + # Handle noscale suffix + noscale = mode_key.endswith("_noscale") + base_key = mode_key[:-len("_noscale")] if noscale else mode_key + suffix = " noscale" if noscale else "" + + # Handle new format: "mode_3_Power" + parts = base_key.split("_", 2) + if len(parts) >= 3: + return parts[2] + suffix # e.g. 
"Power noscale" + # Handle old format: "mode_3" + try: + mode_num = int(parts[1]) + return MAPPING_MODE_NAMES.get(mode_num, base_key) + suffix + except (IndexError, ValueError): + return mode_key + + +def _mode_key_to_number(mode_key: str) -> int: + """Extract the mode number from a key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale'.""" + parts = mode_key.split("_") + try: + return int(parts[1]) + except (IndexError, ValueError): + return -1 + + +def compute_per_segment_stats(histograms: np.ndarray) -> dict: + """Compute per-row Gini coefficient and max/mean ratio. + + Args: + histograms: [num_segments, 256] array of bin counts + + Returns: + dict with 'gini' and 'max_mean' arrays of shape [num_segments] + """ + num_seg = histograms.shape[0] + ginis = np.zeros(num_seg) + max_means = np.zeros(num_seg) + + for i in range(num_seg): + row = histograms[i].astype(np.float64) + nonzero = row[row > 0] + if len(nonzero) == 0: + continue + + max_means[i] = nonzero.max() / nonzero.mean() + + # Gini coefficient + sorted_vals = np.sort(nonzero) + n = len(sorted_vals) + index = np.arange(1, n + 1, dtype=np.float64) + ginis[i] = (2.0 * (index * sorted_vals).sum() / (n * sorted_vals.sum()) - (n + 1) / n) + ginis[i] = max(0.0, ginis[i]) + + return {"gini": ginis, "max_mean": max_means} + + +def plot_bin_distribution(histograms: np.ndarray, output_dir: str, max_segments: int = 4): + """Plot 256-bin bar chart per segment (first N segments).""" + num_seg = min(histograms.shape[0], max_segments) + for i in range(num_seg): + fig, ax = plt.subplots(figsize=(12, 4)) + ax.bar(range(256), histograms[i], width=1.0, color="steelblue", edgecolor="none") + ax.set_xlabel("Bin") + ax.set_ylabel("Count") + ax.set_title(f"Segment {i}: 256-bin histogram") + ax.set_xlim(-1, 256) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"bin_dist_seg_{i}.png"), dpi=150) + plt.close(fig) + print(f" Saved {num_seg} bin distribution plots") + + +def plot_bin_heatmap(histograms: np.ndarray, 
output_dir: str): + """Heatmap: segments x bins, LogNorm colormap.""" + fig, ax = plt.subplots(figsize=(14, max(4, histograms.shape[0] * 0.15 + 1))) + # Add 1 to avoid log(0) + data = histograms.astype(np.float64) + 1 + im = ax.imshow( + data, + aspect="auto", + cmap="viridis", + norm=mcolors.LogNorm(vmin=1, vmax=data.max()), + interpolation="nearest", + ) + ax.set_xlabel("Bin") + ax.set_ylabel("Segment") + ax.set_title("Bin distribution heatmap (log scale)") + fig.colorbar(im, ax=ax, label="Count + 1") + fig.tight_layout() + fig.savefig(os.path.join(output_dir, "bin_heatmap.png"), dpi=150) + plt.close(fig) + print(" Saved bin_heatmap.png") + + +def plot_before_after_mapping( + raw_histograms: np.ndarray, + lut_table: np.ndarray, + output_dir: str, + max_segments: int = 4, +): + """Side-by-side: raw histogram vs. LUT-remapped histogram.""" + num_seg = min(raw_histograms.shape[0], max_segments) + for i in range(num_seg): + raw = raw_histograms[i] + # Remap: redistribute counts through LUT + remapped = np.zeros(256, dtype=np.float64) + for bin_idx in range(256): + new_bin = int(lut_table[bin_idx]) + remapped[new_bin] += raw[bin_idx] + + fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharey=True) + axes[0].bar(range(256), raw, width=1.0, color="steelblue", edgecolor="none") + axes[0].set_title(f"Segment {i}: Raw (mode=0)") + axes[0].set_xlabel("Bin") + axes[0].set_ylabel("Count") + + axes[1].bar(range(256), remapped, width=1.0, color="darkorange", edgecolor="none") + axes[1].set_title(f"Segment {i}: After LUT remap") + axes[1].set_xlabel("Bin") + + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"mapping_comparison_{i}.png"), dpi=150) + plt.close(fig) + print(f" Saved {num_seg} mapping comparison plots") + + +def plot_summary_table( + histograms: np.ndarray, + mode_stats_data: Optional[dict], + output_dir: str, +): + """Per-segment stats table: Gini, max/mean, resolution rate.""" + stats = compute_per_segment_stats(histograms) + num_seg = 
histograms.shape[0] + + col_labels = ["Segment", "Gini", "Max/Mean"] + cell_data = [] + for i in range(num_seg): + cell_data.append([str(i), f"{stats['gini'][i]:.3f}", f"{stats['max_mean'][i]:.2f}"]) + + fig, ax = plt.subplots(figsize=(6, max(2, num_seg * 0.4 + 1))) + ax.axis("off") + table = ax.table(cellText=cell_data, colLabels=col_labels, loc="center", cellLoc="center") + table.auto_set_font_size(False) + table.set_fontsize(9) + table.scale(1.0, 1.3) + ax.set_title("Per-segment distribution stats", fontsize=11, pad=10) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, "summary_table.png"), dpi=150, bbox_inches="tight") + plt.close(fig) + print(" Saved summary_table.png") + + +def plot_distribution_comparison(dist_histograms: dict, output_dir: str, suffix: str = "", title: str = ""): + """Overlay 256-bin distributions for different data sources (uniform, normal, real). + + Args: + dist_histograms: {"uniform": [256], "normal": [256], "real": [256], ...} + output_dir: output directory for the plot + suffix: optional suffix for output filename (e.g. 
"_m0") + title: optional custom title for the plot + """ + names = list(dist_histograms.keys()) + n = len(names) + if n == 0: + print(" No distribution histograms to compare") + return + + fig, axes = plt.subplots(1, n, figsize=(6 * n, 4), squeeze=False) + axes = axes[0] + + for idx, name in enumerate(names): + counts = np.array(dist_histograms[name], dtype=np.float64) + ax = axes[idx] + ax.bar(range(256), counts, width=1.0, color="steelblue", edgecolor="none") + ax.set_xlabel("Bucket") + ax.set_ylabel("Count") + ax.set_xlim(-1, 256) + ax.set_title(name) + + # Annotate with stats + nonzero = counts[counts > 0] + if len(nonzero) > 0: + mean_val = nonzero.mean() + max_val = nonzero.max() + max_mean = max_val / mean_val if mean_val > 0 else 0.0 + sorted_vals = np.sort(nonzero) + nn = len(sorted_vals) + index = np.arange(1, nn + 1, dtype=np.float64) + gini = max(0.0, 2.0 * (index * sorted_vals).sum() / (nn * sorted_vals.sum()) - (nn + 1) / nn) + nz_bins = int(len(nonzero)) + else: + max_mean = gini = 0.0 + nz_bins = 0 + + stats_text = f"gini={gini:.3f}\nmax/mean={max_mean:.2f}\nbins={nz_bins}/256" + ax.text(0.97, 0.95, stats_text, transform=ax.transAxes, + fontsize=8, verticalalignment="top", horizontalalignment="right", + bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7)) + + fig.suptitle(title or "Bucket Distribution Comparison", fontsize=13) + fig.tight_layout() + fname = f"distribution_comparison{suffix}.png" + fig.savefig(os.path.join(output_dir, fname), dpi=150) + plt.close(fig) + print(f" Saved {fname}") + + +def save_bucket_table(dist_histograms: dict, output_dir: str, filename: str = "bucket_counts.csv"): + """Write a CSV table listing the count per bucket for each distribution. + + Columns: bucket, dist1, dist2, ... (256 rows, one per bucket). 
+ """ + import csv + + names = list(dist_histograms.keys()) + if not names: + return + + path = os.path.join(output_dir, filename) + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["bucket"] + names) + for b in range(256): + row = [b] + [int(dist_histograms[n][b]) for n in names] + writer.writerow(row) + + # Also print a compact summary to stdout (top-20 hottest buckets per dist) + print(f" Saved {path}") + for name in names: + counts = np.array(dist_histograms[name], dtype=np.int64) + total = counts.sum() + top_idx = np.argsort(counts)[::-1][:20] + print(f" [{name}] total={total} top-20 hottest buckets:") + for rank, idx in enumerate(top_idx): + if counts[idx] == 0: + break + pct = counts[idx] / total * 100 if total > 0 else 0 + print(f" #{rank+1:2d} bucket {idx:3d}: {counts[idx]:>10d} ({pct:5.1f}%)") + + +def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str): + """Grouped bar chart comparing modes on gini and max/mean.""" + modes = sorted(mode_stats_data.keys()) + if not modes: + print(" No histogram data to plot mode comparison") + return + + mode_labels = [] + for m in modes: + label = _mode_key_to_display(m) + param = mode_stats_data[m].get("param") + if param: + label = f"{label} ({param})" + mode_labels.append(label) + ginis = [mode_stats_data[m]["gini"] for m in modes] + max_means = [mode_stats_data[m]["max_mean_ratio"] for m in modes] + + x = np.arange(len(modes)) + width = 0.3 + + fig, ax1 = plt.subplots(figsize=(max(10, len(modes) * 0.8), 5)) + ax2 = ax1.twinx() + + bars1 = ax1.bar(x - width / 2, ginis, width, label="Gini", color="darkorange") + bars2 = ax2.bar(x + width / 2, max_means, width, label="Max/Mean", color="seagreen", alpha=0.7) + + ax1.set_xlabel("Mapping Mode") + ax1.set_ylabel("Gini") + ax2.set_ylabel("Max/Mean Ratio") + ax1.set_xticks(x) + ax1.set_xticklabels(mode_labels, rotation=30, ha="right") + ax1.set_ylim(0, 1.1) + ax1.set_title("Mapping Mode Comparison") + + # Combine legends + 
lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right") + + fig.tight_layout() + fig.savefig(os.path.join(output_dir, "mode_comparison.png"), dpi=150) + plt.close(fig) + print(" Saved mode_comparison.png") + + +def main(): + parser = argparse.ArgumentParser(description="Analyze TopK bucket sort distribution") + parser.add_argument("--profile-npz", type=str, default=None, + help="Path to .npz from profile_topk_distribution.py") + parser.add_argument("--bench-json", type=str, default=None, + help="Path to JSON from bench_topk.py") + parser.add_argument("--output-dir", type=str, default="plots", + help="Directory for output plots") + parser.add_argument("--max-segments", type=int, default=4, + help="Max segments for per-segment plots") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to .npy raw_histograms from calibrate_topk.py (real-data bucket counts)") + args = parser.parse_args() + + if args.profile_npz is None and args.bench_json is None and args.real_histograms is None: + parser.error("At least one of --profile-npz, --bench-json, or --real-histograms is required") + + os.makedirs(args.output_dir, exist_ok=True) + print(f"Output directory: {args.output_dir}") + + raw_histograms = None + lut_table = None + mode_stats_data = None + + # Load profile data + if args.profile_npz: + print(f"\nLoading profile data from {args.profile_npz}") + data = np.load(args.profile_npz, allow_pickle=True) + if "raw_histograms" in data: + raw_histograms = data["raw_histograms"] + print(f" raw_histograms: {raw_histograms.shape}") + if "aggregate_lut" in data: + lut_table = data["aggregate_lut"] + print(f" aggregate_lut: {lut_table.shape}") + elif "lut_tables" in data: + # Use first LUT if aggregate not available + lut_table = data["lut_tables"] + if lut_table.ndim > 1: + lut_table = lut_table[0] + print(f" lut_table: {lut_table.shape}") + + # 
Load bench data + dist_histograms = {} # {distribution_name: [256] counts} for comparison plot + mode_histograms = {} # {mode_key: {dist_name: [256]}} for per-mode plots + + if args.bench_json: + print(f"\nLoading benchmark data from {args.bench_json}") + with open(args.bench_json) as f: + bench_data = json.load(f) + + if bench_data and isinstance(bench_data, list): + # Use first config entry for histogram mode visualization + entry = bench_data[0] + if "histograms" in entry: + mode_stats_data = entry["histograms"] + print(f" Histogram modes: {list(mode_stats_data.keys())}") + + # Extract raw_counts per distribution from bench entries + for entry in bench_data: + dist_name = entry.get("distribution", "unknown") + hist_data = entry.get("histogram", {}) + if "raw_counts" in hist_data and dist_name not in dist_histograms: + dist_histograms[dist_name] = hist_data["raw_counts"] + print(f" Loaded histogram for distribution: {dist_name}") + + # Extract per-mode histograms from histograms data + mode_histograms = {} # {mode_key: {dist_name: [256]}} + for entry in bench_data: + dist_name = entry.get("distribution", "unknown") + histograms_data = entry.get("histograms", {}) + for mode_key, mode_data in histograms_data.items(): + if isinstance(mode_data, dict) and "raw_counts" in mode_data: + if mode_key not in mode_histograms: + mode_histograms[mode_key] = {} + if dist_name not in mode_histograms[mode_key]: + mode_histograms[mode_key][dist_name] = mode_data["raw_counts"] + if mode_histograms: + print(f" Loaded per-mode histograms for: {sorted(mode_histograms.keys())}") + + # Load real-data histograms from .npy (calibrate_topk.py output) + real_counts = None + if args.real_histograms: + print(f"\nLoading real-data histograms from {args.real_histograms}") + real_hists = np.load(args.real_histograms) # [num_samples, 256] + real_counts = real_hists.sum(axis=0).tolist() # aggregate across samples + dist_histograms["real"] = real_counts + print(f" real_histograms shape: 
{real_hists.shape}, aggregated to [256]") + + # Generate plots + if raw_histograms is not None: + print("\nGenerating histogram plots...") + plot_bin_distribution(raw_histograms, args.output_dir, args.max_segments) + plot_bin_heatmap(raw_histograms, args.output_dir) + plot_summary_table(raw_histograms, mode_stats_data, args.output_dir) + + if lut_table is not None: + print("\nGenerating before/after mapping comparison...") + plot_before_after_mapping(raw_histograms, lut_table, args.output_dir, args.max_segments) + + if mode_stats_data is not None: + print("\nGenerating mode comparison plot...") + plot_mapping_mode_comparison(mode_stats_data, args.output_dir) + + if dist_histograms: + print("\nGenerating distribution comparison plot (raw/unmapped)...") + plot_distribution_comparison(dist_histograms, args.output_dir) + print("\nSaving bucket count table (raw/unmapped)...") + save_bucket_table(dist_histograms, args.output_dir) + + # Per-mode distribution plots and tables + if mode_histograms: + print("\nGenerating per-mode distribution plots and tables...") + for mode_key in sorted(mode_histograms): + mname = _mode_key_to_display(mode_key) + mode_num = _mode_key_to_number(mode_key) + mformula = MAPPING_MODE_FORMULAS.get(mode_num, mname) + # Include hyperparameter value in title if available + param_str = "" + if mode_stats_data and mode_key in mode_stats_data: + param = mode_stats_data[mode_key].get("param") + if param: + param_str = f" [{param}]" + mode_suffix = mname.lower().replace(" ", "_") + plot_distribution_comparison( + mode_histograms[mode_key], args.output_dir, + suffix=f"_{mode_suffix}", + title=f"Bucket Distribution — {mname}{param_str} ({mformula})", + ) + save_bucket_table( + mode_histograms[mode_key], args.output_dir, + filename=f"bucket_counts_{mode_suffix}.csv", + ) + + print(f"\nDone. 
All outputs saved to {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py new file mode 100644 index 0000000..e103953 --- /dev/null +++ b/benchmarks/autotune_topk_mapping.py @@ -0,0 +1,439 @@ +""" +Auto-tune TopK mapping hyperparameters by profiled kernel latency. + +For each (mode, hyperparameter) combo in the sweep grid, this script picks +the hyperparameter whose remapped score distribution produces the lowest +*unfused* topk kernel latency. The measurement is a split-phase: + + 1. topk_remap_only(x, mode, power) → float32 buffer [NOT timed] + 2. topk_output_sglang(remapped) [TIMED] + +Timing only step 2 isolates the Stage-2 radix cost, which is what bucket +uniformity actually affects. The remap cost is the same constant regardless +of power, so it would only pollute the ranking. + +Non-arithmetic baselines (MAPPING_LUT_CDF=1, MAPPING_QUANTILE=2, +MAPPING_TRUNC8=8) route their mapping through compute_stage1_bin, not +apply_transform, so split-phase is a no-op for them. Those are timed via +the fused kernel and marked `timing_mode="fused_fallback"` in the output. + +Distribution statistics (gini, max/mean, counter-based Stage-2 cost) are +still collected for diagnostics, but they do NOT drive the ranking — the +ranking is purely latency-driven. 
+ +Usage: + python benchmarks/autotune_topk_mapping.py \\ + --topk-val 2048 --batch-size 4 --seq-len 65536 --num-kv-heads 8 \\ + --real-histograms calibration/raw_histograms.npy \\ + --output-json autotune_results.json +""" + +import argparse +import json +import math +from typing import Dict, List, Optional + +import numpy as np +import torch + +from bench_topk import ( + make_topk_inputs, + bench_kernel, + compute_histogram_stats, + scores_from_histogram, +) +from vortex_torch_C import ( + topk_output_sglang, + topk_output_sglang_fused, + topk_remap_only, + topk_profile_histogram, + topk_profile_counters, +) + + +# Modes where topk_mapping.cuh::apply_transform is a genuine value-space +# transform (power / asinh / log / log1p / erf / tanh / subtract / exp_stretch, +# plus the top-spreading shift_pow2 / shift_pow3 / linear_steep family) and +# also mode 0 (identity). For these the split-phase `remap_only + unfused +# topk` is correct. Modes 1/2/8 (LUT_CDF / QUANTILE / TRUNC8) apply their +# mapping inside compute_stage1_bin, so split-phase is a no-op. +ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + +# Only parametric modes need auto-tuning. Mode 0 (none) and mode 4 (log) +# have no knob; mode 0 is always the baseline. Sweep grids widened so the +# autotune actually explores the tails of each transform. 
+SWEEP_GRID: Dict[int, List[float]] = { + 3: [0.1, 0.5, 1.0, 2.0, 4.0, 5.0, 9.0], # power: p + 6: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # asinh: beta + 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # log1p: alpha + 9: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # erf: alpha + 10: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # tanh: alpha + 11: [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # subtract: pivot + 13: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # exp_stretch: alpha + 15: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # shift_pow2: pivot + 16: [-4.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # shift_pow3: pivot (widened) + 17: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # linear_steep: k + 18: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_square: pivot + 19: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_cube: pivot + # dense_mant clamp: sweep a wide range because real attention scores + # can span [-400, +200] on some models (raw logits), not just [0, 1]. + 20: [0.0, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0], # dense_mant: clamp pivot +} + +PARAM_NAME = { + 3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", + 15: "pivot", 16: "pivot", 17: "k", + 18: "pivot", 19: "pivot", + 20: "clamp", +} +MODE_NAMES = { + 0: "none", 1: "lut_cdf", 2: "quantile", + 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", 13: "exp_stretch", + 15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep", + 18: "half_square", 19: "half_cube", + 20: "dense_mant", +} + +# Non-parametric modes — no knob to sweep; timed once as a reference point. +# LUT_CDF (1) and QUANTILE (2) are added here at runtime when the caller +# passes --lut-path / --quantiles-path. +BASELINES = [(0, 0.5), (4, 0.5), (8, 0.5)] + + +# ---------- Real-distribution score generation ---------- +# _build_bin_range_table / scores_from_histogram now live in bench_topk.py +# so both autotune and bench_topk draw scores from the same sampler. 
+ + +def _make_real_inputs(args, histogram: np.ndarray) -> dict: + eff_bs = args.batch_size * args.num_kv_heads + num_pages_per_seg = math.ceil(args.seq_len / args.page_size) + total_dense = eff_bs * num_pages_per_seg + sparse_per_seg = min(args.topk_val + args.reserved_bos + args.reserved_eos, num_pages_per_seg) + + dense_kv_indptr = torch.arange( + 0, (eff_bs + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device="cuda", + ) + sparse_kv_indptr = torch.arange( + 0, (eff_bs + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device="cuda", + ) + dense_kv_indices = torch.arange(total_dense, dtype=torch.int32, device="cuda") + sparse_kv_indices = torch.zeros(eff_bs * sparse_per_seg, dtype=torch.int32, device="cuda") + x = scores_from_histogram(histogram, total_dense, device="cuda", + score_dtype=torch.bfloat16) + remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(x.shape) + + return { + "x": x, + "remapped": remapped, + "dense_kv_indptr": dense_kv_indptr, + "sparse_kv_indptr": sparse_kv_indptr, + "dense_kv_indices": dense_kv_indices, + "sparse_kv_indices": sparse_kv_indices, + "eff_batch_size": eff_bs, + "num_pages_per_seg": num_pages_per_seg, + "sparse_per_seg": sparse_per_seg, + } + + +def _ensure_remapped_buffer(inputs: dict) -> torch.Tensor: + """Lazy-allocate a float32 buffer matching x.shape for the split-phase.""" + buf = inputs.get("remapped") + if buf is None: + x = inputs["x"] + buf = torch.empty(x.numel(), dtype=torch.float32, device=x.device).reshape(x.shape) + inputs["remapped"] = buf + return buf + + +# ---------- Latency-based evaluation ---------- + +def _time_fused(inputs, args, mode: int, power: float) -> dict: + """Fused remap+topk kernel latency (used as fallback for modes 1/2/8).""" + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + inputs["sparse_kv_indices"].zero_() + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, 
"_mapping_quantiles", None) if mode == 2 else None + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + return bench_kernel(topk_output_sglang_fused, call_args, + warmup=args.warmup, repeat=args.repeat) + + +def _time_unfused_on_remapped(inputs, args, mode: int, power: float) -> dict: + """Time the unfused topk kernel on pre-remapped scores. + + For mode 0 the original scores are used directly. For every other + arithmetic mode we run topk_remap_only once (not timed) into a + pre-allocated float32 buffer, then time topk_output_sglang on that + buffer with bench_kernel's warmup + repeat loop. This isolates the + Stage-2 radix cost from the remap pass. + """ + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + + if mode == 0: + src = inputs["x"] + else: + remapped = _ensure_remapped_buffer(inputs) + topk_remap_only( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + float(power), + ) + torch.cuda.synchronize() + src = remapped + + inputs["sparse_kv_indices"].zero_() + call_args = ( + src, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + return bench_kernel(topk_output_sglang, call_args, + warmup=args.warmup, repeat=args.repeat) + + +def _time_mode(inputs, args, mode: int, power: float) -> tuple: + """Returns (latency_dict, timing_mode_str).""" + if mode in ARITHMETIC_MODES: + return _time_unfused_on_remapped(inputs, args, mode, power), "unfused_on_remapped" + return _time_fused(inputs, args, mode, power), "fused_fallback" + + +def _collect_diagnostics(inputs, args, mode: int, power: float) -> 
dict: + """Optional distribution/counter stats for reporting only (post-timing).""" + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + diag = {} + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + + if args.collect_stats: + hist = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], inputs["dense_kv_indptr"], hist, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, lut_t, q_t, + ) + torch.cuda.synchronize() + diag.update(compute_histogram_stats(hist)) + + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + inputs["sparse_kv_indices"].zero_() + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + diag["threshold_bin_mean"] = c[:, 0].mean().item() + diag["num_equal_mean"] = c[:, 2].mean().item() + diag["refine_rounds_mean"] = c[:, 4].mean().item() + # selected_from_thr = topk_val - num_above (clamped >= 0). Used as + # a tie-breaker by bench_topk._load_autotune_hparams when several + # modes have indistinguishable latency. + sel_from_thr = (float(args.topk_val) - c[:, 1]).clamp(min=0.0) + diag["selected_from_thr_mean"] = sel_from_thr.mean().item() + + return diag + + +def _run_sweep(args, inputs, dist_label: str) -> List[dict]: + results = [] + + # Baselines: time them but their param is fixed. 
+ for mode, power in BASELINES: + lat, tmode = _time_mode(inputs, args, mode, power) + entry = { + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param_name": "(baseline)", + "param": power, + "distribution": dist_label, + "timing_mode": tmode, + "latency_ms": lat["mean_ms"], + "latency_median_ms": lat["median_ms"], + "latency_min_ms": lat["min_ms"], + } + entry.update(_collect_diagnostics(inputs, args, mode, power)) + results.append(entry) + print( + f" mode={mode:>2d} ({MODE_NAMES[mode]:>10s}) baseline " + f" [{tmode:>20s}] latency={lat['mean_ms']:.4f} ms" + ) + + # Parametric sweep, one (mode, param) combo at a time. + for mode, values in SWEEP_GRID.items(): + pname = PARAM_NAME[mode] + for val in values: + lat, tmode = _time_mode(inputs, args, mode, float(val)) + entry = { + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param_name": pname, + "param": float(val), + "distribution": dist_label, + "timing_mode": tmode, + "latency_ms": lat["mean_ms"], + "latency_median_ms": lat["median_ms"], + "latency_min_ms": lat["min_ms"], + } + entry.update(_collect_diagnostics(inputs, args, mode, float(val))) + results.append(entry) + print( + f" mode={mode:>2d} ({MODE_NAMES[mode]:>10s}) {pname}={val:<6.3f} " + f" [{tmode:>20s}] latency={lat['mean_ms']:.4f} ms" + ) + + return results + + +def _print_ranked(results: List[dict]) -> None: + ranked = sorted(results, key=lambda r: r["latency_ms"]) + header = ( + f"{'Rank':>4s} {'Mode':<12s} {'Param':<14s} {'Dist':<10s} {'Latency (ms)':>14s}" + ) + print("\n" + "=" * len(header)) + print("TopK auto-tune results (ranked by measured kernel latency, lower is better)") + print("=" * len(header)) + print(header) + print("-" * len(header)) + for i, r in enumerate(ranked): + param_str = f"{r['param_name']}={r['param']}" if r["param_name"] != "(baseline)" else "(baseline)" + print( + f"{i + 1:4d} {r['mode_name']:<12s} {param_str:<14s} " + f"{r['distribution']:<10s} {r['latency_ms']:14.4f}" + ) + 
print("=" * len(header)) + + # Best per mode. + best: Dict[int, dict] = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + print("\nBest per mode (by latency):") + for m in sorted(best.keys()): + r = best[m] + param_str = f"{r['param_name']}={r['param']}" if r["param_name"] != "(baseline)" else "(baseline)" + print( + f" mode {m:>2d} ({r['mode_name']:>5s}): {param_str:<16s} " + f"latency={r['latency_ms']:.4f} ms" + ) + + +def main(): + parser = argparse.ArgumentParser("TopK mapping hyperparameter auto-tuner (latency-driven)") + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument("--topk-val", type=int, default=2048) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--reserved-bos", type=int, default=1) + parser.add_argument("--reserved-eos", type=int, default=2) + parser.add_argument("--distributions", type=str, nargs="+", + default=["normal"], + help="Synthetic distributions when --real-histograms is not set.") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibration.") + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--collect-stats", action="store_true", + help="Also collect histogram + counter diagnostics (post-timing, no cost).") + parser.add_argument("--output-json", type=str, default=None) + parser.add_argument("--lut-path", type=str, default=None, + help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).") + parser.add_argument("--quantiles-path", type=str, default=None, + help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") + args = parser.parse_args() + + # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer evaluated — they + # don't 
use topk_mapping::apply_transform (their mapping is done inside + # compute_stage1_bin) and are kept out of the comparison entirely. + args._mapping_lut = None + args._mapping_quantiles = None + + real_histogram: Optional[np.ndarray] = None + if args.real_histograms: + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + + all_results: List[dict] = [] + + if real_histogram is not None: + inputs = _make_real_inputs(args, real_histogram) + print("\n=== Latency sweep on REAL distribution " + f"(batch={args.batch_size} heads={args.num_kv_heads} seq={args.seq_len} topk={args.topk_val}) ===") + all_results += _run_sweep(args, inputs, "real") + else: + for dist in args.distributions: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist, + ) + print(f"\n=== Latency sweep on synthetic dist={dist} ===") + all_results += _run_sweep(args, inputs, dist) + + _print_ranked(all_results) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py new file mode 100644 index 0000000..3653c55 --- /dev/null +++ b/benchmarks/bench_topk.py @@ -0,0 +1,1035 @@ +""" +TopK kernel benchmarking suite. + +Lean rewrite after the remap-benchmark refactor. Exposes three public +helpers used by autotune_topk_mapping.py (make_topk_inputs, bench_kernel, +compute_histogram_stats) and a CLI with two modes: + + - default : time the baseline (unmapped) kernel and the fused + kernel across a grid of (mode, power, batch, seq_len, + topk_val, distribution) configs. 
+ - --remap-bench: time baseline vs fused vs split-phase (remap-only + + unmapped-topk-on-remapped) and report threshold stats + from topk_profile_counters. +""" + +import argparse +import json +import math +import statistics +from typing import Dict, List + +import numpy as np +import torch + +from vortex_torch_C import ( + topk_output, # full CUB BlockRadixSort topk (max 4096 pages/seg) + topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) + topk_output_sglang_fused, # fused remap + 2-stage radix topk + topk_output_sglang_ori, # original SGLang reference kernel + topk_remap_only, # standalone value-space remap + topk_profile_histogram, + topk_profile_counters, +) + +# topk_output's template ladder tops out at 8192 pages per segment +# (see topk.cu::topk_output, branches up to <= 8192). Runs larger than +# that hit TORCH_CHECK(false). +TOPK_OUTPUT_MAX_PAGES = 8192 + +# The ori kernel has TopK baked in at compile time. If setup.py was built +# with a different value, calls will fail; this is the topk_val that +# matches the current build of topk_sglang_ori.cu. +TOPK_ORI_BAKED_IN = 30 + + +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT_CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", + 13: "ExpStretch", + 15: "ShiftPow2", + 16: "ShiftPow3", + 17: "LinearSteep", + 18: "HalfSquare", + 19: "HalfCube", + 20: "DenseMant", +} + +# Modes whose value-space transform is a real apply_transform() pass. Modes +# 1 (LUT_CDF), 2 (QUANTILE) and 8 (TRUNC8) apply their mapping inside +# compute_stage1_bin, not apply_transform — so `topk_remap_only` cannot +# reproduce them (the fp32 buffer would just contain the raw values). For +# those modes the split-phase numbers are N/A; only the fused kernel is a +# meaningful reference. 
+ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + +_AUTOTUNE_TIE_TOLERANCE_MS = 0.0002 # ≈ CUDA event noise floor at this kernel size + + +def _load_autotune_hparams(path: str) -> Dict[int, float]: + """Load per-mode best hyperparameters from an autotune_results.json. + + The JSON is produced by autotune_topk_mapping.py and contains a list of + {mode, param, latency_ms, num_equal_mean, selected_from_thr_mean, ...} + entries. For each mode we group all sweep entries, find the lowest + latency, then break ties (within `_AUTOTUNE_TIE_TOLERANCE_MS`) by: + + 1. Smallest `num_equal_mean` (= thr_size). Stage-2 cost is O(thr_size), + so a smaller threshold bin is a better proxy for real fused + latency than the noisy `latency_ms` measurement. + 2. Smallest `selected_from_thr_mean`. How many pages the topk has to + pull from the threshold bin during refinement. + 3. Lowest `latency_ms` again (final fallback). + + Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; + the caller should override to taste. + """ + with open(path) as f: + data = json.load(f) + grouped: Dict[int, list] = {} + for r in data: + m = r.get("mode") + lat = r.get("latency_ms") + if m is None or lat is None: + continue + grouped.setdefault(m, []).append(r) + + best: Dict[int, dict] = {} + for m, entries in grouped.items(): + min_lat = min(e["latency_ms"] for e in entries) + contenders = [ + e for e in entries + if e["latency_ms"] - min_lat <= _AUTOTUNE_TIE_TOLERANCE_MS + ] + # Tie-breakers: lowest num_equal_mean, then lowest sel_thr, + # then lowest latency. Missing diagnostic fields → +inf so they + # lose tie-breaks (we still keep them as fallback candidates). 
+ def _rank_key(e): + return ( + e.get("num_equal_mean", float("inf")), + e.get("selected_from_thr_mean", float("inf")), + e["latency_ms"], + ) + best[m] = min(contenders, key=_rank_key) + + return {m: float(r["param"]) for m, r in best.items()} + + +def _key_to_fp16(key: int) -> np.float16: + """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" + bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) + return np.array([bits], dtype=np.uint16).view(np.float16)[0] + + +def build_bin_range_table(): + """Per-bin (lo, hi) fp16 value tables for the 256 Stage-1 radix bins. + + Shared by the real-distribution samplers in bench_topk.py and + autotune_topk_mapping.py so both scripts generate identical inputs. + """ + all_bits = np.arange(65536, dtype=np.uint16) + all_fp16 = all_bits.view(np.float16) + keys = np.where( + (all_bits & 0x8000).astype(bool), + (~all_bits).astype(np.uint16), + all_bits | np.uint16(0x8000), + ) + bins = (keys >> 8).astype(np.uint8) + all_f32 = all_fp16.astype(np.float32) + valid = np.isfinite(all_f32) + bin_lo = np.full(256, np.inf, dtype=np.float32) + bin_hi = np.full(256, -np.inf, dtype=np.float32) + for b in range(256): + mask = (bins == b) & valid + if mask.any(): + vals = all_f32[mask] + bin_lo[b] = vals.min() + bin_hi[b] = vals.max() + empty = bin_lo > bin_hi + for b in np.where(empty)[0]: + val = float(_key_to_fp16((int(b) << 8) | 0x80)) + bin_lo[b] = val + bin_hi[b] = val + return bin_lo, bin_hi + + +def scores_from_histogram( + histogram: np.ndarray, + total_pages: int, + device: str = "cuda", + score_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Sample `total_pages` scores whose Stage-1 bucket distribution matches + the given 256-bin histogram (produced by calibration). 
Each bucket is + sampled uniformly over the fp16 range that maps into it.""" + bin_lo, bin_hi = build_bin_range_table() + counts = histogram.astype(np.float64) + total = counts.sum() + if total == 0: + return torch.zeros(total_pages, 1, 1, dtype=score_dtype, device=device) + probs = counts / total + bin_indices = np.random.choice(256, size=total_pages, p=probs) + lo = bin_lo[bin_indices] + hi = bin_hi[bin_indices] + rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) + scores_f32 = lo + rand * (hi - lo) + return torch.from_numpy(scores_f32).to(score_dtype).reshape(total_pages, 1, 1).to(device) + + +def make_topk_inputs( + batch_size: int, + num_kv_heads: int, + seq_len: int, + page_size: int, + topk_val: int, + reserved_bos: int, + reserved_eos: int, + score_dtype: torch.dtype, + distribution: str = "normal", + real_histogram: np.ndarray = None, + device: str = "cuda", +) -> dict: + """Synthesize CSR-formatted paged attention inputs for kernel timing. + + When `real_histogram` is provided, scores are drawn from that 256-bin + distribution (ignoring `distribution`) so the benchmark sees the same + Stage-1 bucket distribution as the calibrated model. 
+ """ + eff_batch_size = batch_size * num_kv_heads + num_pages_per_seg = math.ceil(seq_len / page_size) + total_dense_pages = eff_batch_size * num_pages_per_seg + sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg) + total_sparse_pages = eff_batch_size * sparse_per_seg + + dense_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device=device, + ) + sparse_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device=device, + ) + dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) + sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) + + if real_histogram is not None: + x = scores_from_histogram(real_histogram, total_dense_pages, device=device, + score_dtype=score_dtype) + elif distribution == "normal": + x = torch.randn(total_dense_pages, 1, 1, device=device).to(score_dtype) + elif distribution == "lognormal": + x = torch.randn(total_dense_pages, 1, 1, device=device).exp().to(score_dtype) + elif distribution == "uniform": + x = torch.rand(total_dense_pages, 1, 1, device=device).to(score_dtype) + elif distribution == "bucket_uniform": + # Uniform across all 256 fp16 radix buckets. Random uint16 bit + # patterns → interpret as fp16. NaN/Inf patterns collapse to ±0. 
+ raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device) + abs_bits = raw_bits & 0x7FFF + raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 + x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1).to(score_dtype) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + return { + "x": x, + "dense_kv_indptr": dense_kv_indptr, + "sparse_kv_indptr": sparse_kv_indptr, + "dense_kv_indices": dense_kv_indices, + "sparse_kv_indices": sparse_kv_indices, + "eff_batch_size": eff_batch_size, + "num_pages_per_seg": num_pages_per_seg, + "sparse_per_seg": sparse_per_seg, + } + + +def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict: + """Time a kernel with CUDA events. Returns latency stats in ms.""" + for _ in range(warmup): + kernel_fn(*args) + torch.cuda.synchronize() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + for i in range(repeat): + start_events[i].record() + kernel_fn(*args) + end_events[i].record() + torch.cuda.synchronize() + + times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + return { + "mean_ms": statistics.mean(times), + "median_ms": statistics.median(times), + "std_ms": statistics.stdev(times) if len(times) > 1 else 0.0, + "min_ms": min(times), + "max_ms": max(times), + } + + +def compute_histogram_stats(histograms: torch.Tensor) -> dict: + """Bin distribution statistics from histogram tensor [B, 256].""" + h = histograms.float() + h_sum = h.sum(dim=0) # [256] + nonzero = h_sum[h_sum > 0] + if len(nonzero) == 0: + return { + "max_mean_ratio": 0.0, "std": 0.0, "gini": 0.0, + "num_nonzero_bins": 0, "entropy": 0.0, "effective_bins": 0.0, + } + mean_val = nonzero.mean().item() + max_val = nonzero.max().item() + std_val = nonzero.std().item() if len(nonzero) > 1 else 0.0 + sorted_bins = 
nonzero.sort().values + n = len(sorted_bins) + idx = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) + gini = (2.0 * (idx * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() + p = nonzero / nonzero.sum() + entropy = -(p * p.log2()).sum().item() + return { + "max_mean_ratio": max_val / mean_val if mean_val > 0 else 0.0, + "std": std_val, + "gini": max(0.0, gini), + "num_nonzero_bins": int(len(nonzero)), + "entropy": entropy, + "effective_bins": 2 ** entropy, + } + + +def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, power: float) -> dict: + """Run topk_profile_counters + topk_profile_histogram once and aggregate + threshold-bin / bucket-distribution stats. Profile kernels run AFTER all + latency measurements, so their writes never contaminate timing. + """ + eff_bs = inputs["eff_batch_size"] + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + inputs["sparse_kv_indices"].zero_() + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + + # Run the 256-bin histogram profile to compute the rank_target_bins + # metric: how many bins ABOVE the threshold bin (i.e. the bins whose + # pages are selected without Stage-2 refinement) actually contain + # selected pages, and the mean pages-per-such-bin. 
+ hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], + inputs["dense_kv_indptr"], + hist_buf, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + + thr_idx = counter_buf[:, 0].to(torch.int64) # [eff_bs] + hist = hist_buf.to(torch.int64) # [eff_bs, 256] + bin_ids = torch.arange(256, device="cuda", dtype=torch.int64).unsqueeze(0) # [1, 256] + above_mask = bin_ids > thr_idx.unsqueeze(1) # [eff_bs, 256] + above_populated = ((hist > 0) & above_mask).sum(dim=1).float() # bins >thr with any pages + pages_above = (hist * above_mask.to(torch.int64)).sum(dim=1).float() # total pages in those bins + # Mean pages per populated above-threshold bin (per-segment, then + # averaged). Guard against divide-by-zero. + pages_per_bin = torch.where( + above_populated > 0, + pages_above / above_populated, + torch.zeros_like(above_populated), + ) + + # Selected from threshold bin = topk_val - num_above (clamped >= 0). + sel_from_thr = (float(topk_val) - c[:, 1]).clamp(min=0.0) + return { + "threshold_bin_mean": c[:, 0].mean().item(), + "threshold_bin_max": c[:, 0].max().item(), + "num_above_mean": c[:, 1].mean().item(), + "threshold_bin_size_mean": c[:, 2].mean().item(), # NUM_EQUAL + "threshold_bin_size_max": c[:, 2].max().item(), + "selected_from_thr_mean": sel_from_thr.mean().item(), + "selected_from_thr_max": sel_from_thr.max().item(), + "refine_rounds_mean": c[:, 4].mean().item(), + # Rank-target metrics: how the top pages are actually spread. 
+ "above_bins_mean": above_populated.mean().item(), + "pages_per_above_bin_mean": pages_per_bin.mean().item(), + } + + +def _resolve_hparam(args, mode: int) -> float: + """Pick the hyperparameter for a mode: autotune JSON wins, then --mapping-hparam.""" + if mode == 0: + return 0.5 # unused for MAPPING_NONE + hparams: Dict[int, float] = getattr(args, "_autotune_hparams", {}) or {} + if mode in hparams: + return hparams[mode] + return args.mapping_hparam + + +def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, + distribution, modes: List[int], + head_label: str = "all") -> dict: + """Time baseline, fused, and split-phase for each mode at one config. + + `head_label` is metadata: ``"all"`` for the aggregated table (default), + or a stringified head index ``"0".."N-1"`` for per-head benches. The + caller is responsible for setting ``args._real_histogram`` to the + head-sliced sub-histogram before invoking this function in per-head mode. + """ + real_hist = getattr(args, "_real_histogram", None) if distribution == "real" else None + inputs = make_topk_inputs( + batch_size=batch_size, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=distribution if distribution != "real" else "normal", + real_histogram=real_hist, + ) + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + total_dense = inputs["x"].numel() + + # Baseline = unmapped topk_output_sglang (CUB two-stage radix, the + # kernel every mapped mode's split-phase ends up calling). This is + # the `base_us` column and also what the `None` row reports, so + # None's topk_us == base_us by construction. 
+ baseline_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + baseline = bench_kernel(topk_output_sglang, baseline_args, args.warmup, args.repeat) + + # Optional extra row: the full CUB BlockRadixSort topk from topk.cu. + # This is a "true naive" — exact sort, no bucketing tricks — for A/B + # against the 2-stage approximate baseline. Only runs when pages_per_seg + # fits the kernel's template ladder (<= TOPK_OUTPUT_MAX_PAGES = 4096). + naive_ms = None + if pages_per_seg <= TOPK_OUTPUT_MAX_PAGES: + naive_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["dense_kv_indices"], # NOTE: topk_output arg order differs + inputs["sparse_kv_indptr"], # from topk_output_sglang + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + naive_ms = bench_kernel( + topk_output, naive_args, args.warmup, args.repeat + )["mean_ms"] + + # Optional extra row: the original SGLang kernel from topk_sglang_ori.cu, + # compiled with TopK=TOPK_ORI_BAKED_IN. Only runs when topk_val matches + # that constant; otherwise the row is skipped with a warning. It is NOT + # used as the baseline — this is a separate A/B point so you can see the + # ori-vs-naive gap at a glance. + sglang_ori_ms = None + if topk_val == TOPK_ORI_BAKED_IN: + ori_indices = torch.empty(eff_bs, TOPK_ORI_BAKED_IN, + dtype=torch.int32, device="cuda") + ori_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + ori_indices, + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + sglang_ori_ms = bench_kernel( + topk_output_sglang_ori, ori_args, args.warmup, args.repeat + )["mean_ms"] + + # Pre-allocate the float32 buffer used for the split-phase (remap → baseline). 
+ # Split-phase remapped buffer is **float32** to preserve Stage-2 + # refinement precision. The fused kernel computes transforms in + # fp32 internally (so its Stage-2 sub-bin keys carry transform- + # dependent bits in positions [15:0]); a narrower remapped buffer + # (bf16 or fp16) would zero those bits on round-trip and change + # the Stage-2 tie-break ordering vs the fused path. fp32 is the + # only lossless choice. The kernel supports bf16 output too (see + # topk_remap_only's dispatch table) for experimental paths, but we + # don't use it here because correctness matters more than the + # small memory-bandwidth win. + remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(inputs["x"].shape) + + config = { + "batch_size": batch_size, + "num_kv_heads": num_kv_heads, + "seq_len": seq_len, + "topk_val": topk_val, + "distribution": distribution, + "pages_per_seg": pages_per_seg, + "head": head_label, + "baseline_ms": baseline["mean_ms"], + "naive_ms": naive_ms, + "sglang_ori_ms": sglang_ori_ms, + "modes": [], + } + + # Naive row — full CUB BlockRadixSort from topk.cu. No mapping, no + # remap, no fused. Only populated when pages_per_seg fits the kernel. + if naive_ms is not None: + config["modes"].append({ + "mode": -2, # sentinel so ranking/autotune skip it + "mode_name": "Naive", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": naive_ms, + "split_total_ms": None, + "fused_ms": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + + # The None row is a pass-through to the naive baseline: no remap, no + # fused, and topk_us == base_us by construction. 
Distribution metrics + # are populated by running the profile kernels with mode=0 so the user + # can see the unmapped Stage-1 bucket layout as a reference. + none_stats = _collect_threshold_stats( + inputs, topk_val, pages_per_seg, args, mode=0, power=0.5 + ) + config["modes"].append({ + "mode": 0, + "mode_name": "None", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": baseline["mean_ms"], + "split_total_ms": None, + "fused_ms": None, + **none_stats, + }) + + # Extra row for the original SGLang kernel — only populated when the + # build's baked-in TopK matches topk_val. Also a pass-through (no + # remap, no fused); topk_us is the ori kernel latency. + if sglang_ori_ms is not None: + config["modes"].append({ + "mode": -1, # sentinel so ranking/autotune skip it + "mode_name": "sglang_ori", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": sglang_ori_ms, + "split_total_ms": None, + "fused_ms": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + else: + print(f"[bench-remap] sglang_ori row SKIPPED: topk_val={topk_val} != " + f"TOPK_ORI_BAKED_IN ({TOPK_ORI_BAKED_IN}). Rebuild topk_sglang_ori.cu " + f"with a matching TopK to enable the row.") + + for mode in modes: + # Mode 0 is already emitted as the `None` row above (pass-through + # to the ori baseline with no remap/fused). Skip to avoid a + # duplicate row and a spurious fused-mode-0 measurement. 
+ if mode == 0: + continue + + power = _resolve_hparam(args, mode) + + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + fused_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, + ) + inputs["sparse_kv_indices"].zero_() + fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) + + # Split-phase timing is only meaningful for arithmetic modes. + # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside + # compute_stage1_bin, which topk_remap_only cannot reproduce, so we + # report N/A for the split-phase fields and rely on the fused kernel + # as the only valid reference latency. + if mode in ARITHMETIC_MODES: + remap_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, + ) + remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + + # Populate the remapped buffer once so the unfused-topk warmup + # iterations don't read stale data. 
+ topk_remap_only(*remap_args) + torch.cuda.synchronize() + split_topk_args = ( + remapped, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) + + remap_ms = remap_only["mean_ms"] + topk_after_remap_ms = split_topk["mean_ms"] + split_total_ms = remap_ms + topk_after_remap_ms + else: + remap_ms = None + topk_after_remap_ms = None + split_total_ms = None + + # Counter collection is run AFTER all timing measurements for this mode + # so it cannot affect the timings. + stats = _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode, power) + + row = { + "mode": mode, + "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, + "remap_ms": remap_ms, + "topk_after_remap_ms": topk_after_remap_ms, + "split_total_ms": split_total_ms, + "fused_ms": fused["mean_ms"], + **stats, + } + config["modes"].append(row) + + return config + + +# Stage-2 working-set cap, matches SMEM_INPUT_SIZE in fast_topk_clean_fused +# (32 KB dynamic smem / 2 ping-pong buffers / 4 bytes per int = 4096). +_STAGE2_SMEM_CAP = 4096 + + +def _print_remap_table(results: List[dict]) -> None: + # The printed table only carries metrics that participate in the + # fused-kernel cost model. All purely-informational columns + # (thr_bin / sel_thr / abv_bins / pg/bin) were dropped — they're + # still in the JSON for downstream tools, just not in the table. 
+ header = ( + f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " + f"{'fused_ms':>9s} {'base_ms':>9s} " + f"{'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} {'s2_work':>8s}" + ) + for cfg in results: + banner = ( + f"\n[batch={cfg['batch_size']} heads={cfg['num_kv_heads']} " + f"seq_len={cfg['seq_len']} topk={cfg['topk_val']} " + f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']} " + f"head={cfg.get('head', 'all')}]" + ) + print(banner) + extra_notes = [] + if cfg.get("naive_ms") is not None: + extra_notes.append("Naive row = topk.cu (CUB full sort)") + if cfg.get("sglang_ori_ms") is not None: + extra_notes.append("sglang_ori row = topk_sglang_ori.cu") + notes_str = "" + if extra_notes: + notes_str = " | " + " | ".join(extra_notes) + print(f" Baseline: topk_sglang.cu (CUB two-stage){notes_str}") + print( + f" s1p2_load = thr_size (uncapped global re-reads in Stage-1 pass 2) " + f"eff_thr = min(thr_size, {_STAGE2_SMEM_CAP}) " + f"rounds = stage-2 passes (1..4) " + f"s2_work = rounds * eff_thr" + ) + print(header) + print("-" * len(header)) + base_ms = cfg["baseline_ms"] + for row in cfg["modes"]: + if row["mode"] == 0: + label = "None" + elif row["mode"] == -1: + label = row.get("mode_name", "sglang_ori") + elif row["mode"] == -2: + label = row.get("mode_name", "Naive") + else: + label = f"{row['mode_name']}(p={row['power']})" + def _fmt(v): + return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" + fused_str = _fmt(row.get("fused_ms")) + thr_size = row.get("threshold_bin_size_mean", 0.0) + rounds = row.get("refine_rounds_mean", 0.0) + eff_thr = min(thr_size, float(_STAGE2_SMEM_CAP)) + s2_work = rounds * eff_thr + s1p2_load = thr_size # alias: same number, named for the cost-model role + print( + f"{label:<14s} " + f"{_fmt(row['remap_ms'])} " + f"{_fmt(row['topk_after_remap_ms'])} " + f"{_fmt(row['split_total_ms'])} " + f"{fused_str} " + f"{base_ms:9.4f} " + f"{s1p2_load:9.0f} " + f"{eff_thr:7.0f} " + f"{rounds:6.2f} " + 
f"{s2_work:8.0f}" + ) + + +def _combine_per_head_cfgs(per_head_cfgs: List[dict]) -> dict: + """Combine a list of per-head cfg dicts (same shape, head='0','1',...) + into a single aggregated cfg tagged head='all', by averaging every + numeric field. This is used when --per-head-bench is on so the + aggregated row reflects the realistic per-head behaviour rather than + a separate kernel launch on an averaged histogram. + + Assumes every cfg has the same `modes` list in the same order — which + holds because all per-head sub-runs use identical (batch, heads, seq, + topk, page_size, reserved, mapping_modes) parameters and therefore + take the same code paths through `_remap_bench_one_config`. + """ + assert per_head_cfgs, "_combine_per_head_cfgs called with empty list" + base = per_head_cfgs[0] + n_modes = len(base["modes"]) + # Sanity: same shape. + for c in per_head_cfgs[1:]: + assert len(c["modes"]) == n_modes, ( + f"per-head cfgs disagree on mode count: {n_modes} vs {len(c['modes'])}" + ) + + def _mean_or_none(vals): + vs = [v for v in vals if v is not None] + return (sum(vs) / len(vs)) if vs else None + + combined: Dict = { + "batch_size": base["batch_size"], + "num_kv_heads": base["num_kv_heads"], + "seq_len": base["seq_len"], + "topk_val": base["topk_val"], + "distribution": base["distribution"], + "pages_per_seg": base["pages_per_seg"], + "head": "all", + "baseline_ms": _mean_or_none([c.get("baseline_ms") for c in per_head_cfgs]), + "naive_ms": _mean_or_none([c.get("naive_ms") for c in per_head_cfgs]), + "sglang_ori_ms": _mean_or_none([c.get("sglang_ori_ms") for c in per_head_cfgs]), + "modes": [], + } + + # Numeric fields per mode row that we average; non-numeric fields (mode, + # mode_name, power) are copied from the first cfg since they're identical + # across heads by construction. 
+ NUMERIC_KEYS = ( + "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms", + "threshold_bin_mean", "threshold_bin_max", + "num_above_mean", + "threshold_bin_size_mean", "threshold_bin_size_max", + "selected_from_thr_mean", "selected_from_thr_max", + "refine_rounds_mean", + "above_bins_mean", "pages_per_above_bin_mean", + ) + for mi in range(n_modes): + sample = base["modes"][mi] + merged = { + "mode": sample["mode"], + "mode_name": sample["mode_name"], + "power": sample["power"], + } + for key in NUMERIC_KEYS: + merged[key] = _mean_or_none([c["modes"][mi].get(key) for c in per_head_cfgs]) + combined["modes"].append(merged) + return combined + + +def _run_remap_bench(args) -> None: + modes = [int(m) for m in args.mapping_modes] + # Mode 0 is emitted as the "None" row from _remap_bench_one_config + # itself (pass-through to the ori baseline). Drop any user-supplied 0 + # to avoid a duplicate row. + modes = [m for m in modes if m != 0] + + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None: + if "real" not in distributions: + distributions.append("real") + print(f"[remap-bench] 'real' distribution enabled " + f"(histogram total count = {int(args._real_histogram.sum())})") + + if getattr(args, "per_head_bench", False): + if getattr(args, "_real_histograms_raw", None) is None: + raise SystemExit( + "[bench-remap] --per-head-bench requires --real-histograms with a 2D raw file." + ) + if not args.num_kv_heads or any(h <= 0 for h in args.num_kv_heads): + raise SystemExit("[bench-remap] --per-head-bench requires --num-kv-heads > 0.") + # When the user passes multiple --num-kv-heads values we slice by the + # first one (the others are degenerate for per-head reporting since + # the histogram file has a fixed head count). 
+ per_head_count = int(args.num_kv_heads[0]) + + results = [] + # When --per-head-bench is on, each "real"-distribution aggregate is + # built by averaging the 8 per-head measurements (NOT by running an + # extra kernel on an averaged histogram). This grouping keeps the + # per-head cfgs that should fold into each (bs, heads, seq, topk) + # aggregate point. + per_head_groups: dict = {} + + # ---- Per-head tables (printed first) ---- + if getattr(args, "per_head_bench", False): + raw = args._real_histograms_raw + saved_agg = args._real_histogram + try: + for h in range(per_head_count): + # Slice rows belonging to head `h`. Rows are interleaved as + # row_idx % num_kv_heads = head_idx, so this strided slice + # collects all (call, batch, h) triples across the file. + args._real_histogram = raw[h::per_head_count].sum(axis=0) + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, "real", modes, + head_label=str(h), + ) + results.append(cfg) + per_head_groups.setdefault( + (bs, heads, seq_len, topk_val), [] + ).append(cfg) + finally: + args._real_histogram = saved_agg + + # ---- Aggregated tables (printed last) ---- + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for dist in distributions: + if dist == "real" and getattr(args, "per_head_bench", False): + cfgs = per_head_groups.get((bs, heads, seq_len, topk_val), []) + if cfgs: + # Combine the per-head cfgs into a single + # aggregated row — no extra kernel launch. 
+ cfg = _combine_per_head_cfgs(cfgs) + results.append(cfg) + continue + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, dist, modes, + head_label="all", + ) + results.append(cfg) + + _print_remap_table(results) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +def _run_latency_sweep(args) -> None: + """Simple baseline-vs-fused latency sweep (no split-phase, no counters).""" + modes = [int(m) for m in args.mapping_modes] + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None and "real" not in distributions: + distributions.append("real") + results = [] + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for dist in distributions: + real_hist = args._real_histogram if dist == "real" else None + inputs = make_topk_inputs( + batch_size=bs, num_kv_heads=heads, seq_len=seq_len, + page_size=args.page_size, topk_val=topk_val, + reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist if dist != "real" else "normal", + real_histogram=real_hist, + ) + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + row_modes = [] + for mode in modes: + power = _resolve_hparam(args, mode) + inputs["sparse_kv_indices"].zero_() + if mode == 0: + call = topk_output_sglang + call_args = ( + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + else: + call = topk_output_sglang_fused + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + call_args = ( + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], 
inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, + ) + stats = bench_kernel(call, call_args, args.warmup, args.repeat) + row_modes.append({ + "mode": mode, "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, "mean_ms": stats["mean_ms"], + "median_ms": stats["median_ms"], + }) + print( + f"bs={bs} h={heads} seq={seq_len} topk={topk_val} " + f"dist={dist} mode={mode:>2d} lat={stats['mean_ms']:.4f} ms" + ) + results.append({ + "batch_size": bs, "num_kv_heads": heads, "seq_len": seq_len, + "topk_val": topk_val, "distribution": dist, "modes": row_modes, + }) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +def main(): + p = argparse.ArgumentParser("TopK kernel benchmarks") + p.add_argument("--batch-sizes", type=int, nargs="+", default=[4]) + p.add_argument("--num-kv-heads", type=int, nargs="+", default=[8]) + p.add_argument("--seq-lens", type=int, nargs="+", default=[8192]) + p.add_argument("--topk-vals", type=int, nargs="+", default=[30]) + p.add_argument("--distributions", type=str, nargs="+", + default=["normal"], + choices=["normal", "lognormal", "uniform", "bucket_uniform", "real"], + help="Synthetic distributions. Use 'real' (or --real-histograms) to " + "sample scores from a calibrated raw_histograms.npy.") + p.add_argument("--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibrate_topk.py. 
When set, a " + "'real' distribution is appended to the sweep so every " + "(mode, hparam) combo is also timed on the calibrated score " + "distribution.") + p.add_argument("--mapping-modes", type=int, nargs="+", + default=[0, 3, 6, 7], + help="Mapping modes to sweep (0=None, 3=Power, 6=Asinh, 7=Log1p, etc.)") + p.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, + dest="mapping_hparam", + help="Fallback hyperparameter for every non-zero mapping mode when " + "no --autotune-json is provided: p for mode 3 (power), beta for " + "mode 6 (asinh), alpha for modes 7/9/10/13 (log1p/erf/tanh/exp_stretch).") + p.add_argument("--autotune-json", type=str, default=None, + help="Path to autotune_results.json produced by autotune_topk_mapping.py. " + "When set, the per-mode hyperparameter with the lowest measured " + "latency in that file is used instead of --mapping-hparam.") + p.add_argument("--lut-path", type=str, default=None, + help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).") + p.add_argument("--quantiles-path", type=str, default=None, + help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") + p.add_argument("--page-size", type=int, default=16) + p.add_argument("--reserved-bos", type=int, default=1) + p.add_argument("--reserved-eos", type=int, default=2) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--repeat", type=int, default=100) + p.add_argument("--output-json", type=str, default=None) + p.add_argument("--remap-bench", action="store_true", + help="Run the split-phase remap/topk/fused/baseline benchmark.") + p.add_argument("--per-head-bench", action="store_true", + help="In addition to the aggregated 'real'-distribution table, also " + "run the remap-bench once per KV head: slice the calibrated " + "histogram into one sub-histogram per head (using " + "row_idx %% num_kv_heads = head_idx), bench each, and print one " + "table per head followed by the aggregated table. 
Requires " + "--real-histograms (with a 2D raw file) and --num-kv-heads.") + args = p.parse_args() + + args._autotune_hparams = {} + if args.autotune_json: + args._autotune_hparams = _load_autotune_hparams(args.autotune_json) + print(f"[autotune] using best-latency hyperparameters from {args.autotune_json}:") + for m, v in sorted(args._autotune_hparams.items()): + print(f" mode {m:>2d} -> {v}") + + args._real_histogram = None + args._real_histograms_raw = None + if args.real_histograms: + # mmap_mode='r' keeps the (potentially 20+ GB) raw file off-heap; we + # only materialise per-head sums when --per-head-bench is set. + raw = np.load(args.real_histograms, mmap_mode='r') + args._real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + if raw.ndim > 1: + args._real_histograms_raw = raw + print(f"[real] loaded calibrated histogram from {args.real_histograms} " + f"(shape={raw.shape} → [256] aggregate)") + + args._mapping_lut = None + args._mapping_quantiles = None + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + assert lut_np.shape == (256,), f"LUT must be [256], got {lut_np.shape}" + args._mapping_lut = torch.from_numpy(lut_np).cuda() + print(f"[mapping] loaded LUT from {args.lut_path}") + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + assert q_np.shape == (256,), f"quantiles must be [256], got {q_np.shape}" + args._mapping_quantiles = torch.from_numpy(q_np).cuda() + print(f"[mapping] loaded quantiles from {args.quantiles_path}") + + if args.remap_bench: + _run_remap_bench(args) + else: + _run_latency_sweep(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py new file mode 100644 index 0000000..f3343aa --- /dev/null +++ b/benchmarks/calibrate_topk.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Offline calibration for TopK mapping modes 1 (LUT CDF) and 2 (quantile). 
+ +Runs the model on real data with hit-rate profiling enabled, collects score +histograms from the topk_sglang kernel, and generates: + - lut.npy : uint8[256] CDF-equalized LUT for mapping mode 1 + - quantiles.npy: float32[256] quantile breakpoints for mapping mode 2 + +Usage: + python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration_output/ +""" + +import argparse +import json +import os +import shutil +import sys + +import numpy as np + +# Add project root to path so we can import from benchmarks/ +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from benchmarks.profile_topk_distribution import ( + compute_lut_from_histogram, + generate_tables_from_histograms, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Offline calibration for TopK mapping modes 1 & 2" + ) + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--mem", type=float, default=0.7) + parser.add_argument( + "--max-total-tokens", + type=int, + default=1048576, + help="Hard cap on KV pool token slots (ServerArgs.max_total_tokens). " + "Block-sparse profiling uses a small bytes/token estimate, so the auto " + "budget can be huge on large GPUs; VTXGraphAttnBackend then allocates " + "dense bf16 sparse_prefill K/V buffers proportional to this cap (~4 KiB per " + "token per buffer). For offline calibration, a few hundred K to 1M tokens " + "is usually enough.", + ) + parser.add_argument( + "--min-free-disk-gb", + type=float, + default=20.0, + help="Abort if the filesystem for --output-dir (and HF cache, typically the same) " + "has less than this many GiB free. First-time model downloads need many GiB. 
" + "Set to 0 to disable.", + ) + parser.add_argument("--kv-cache-dtype", type=str, default="auto") + parser.add_argument("--topk-type", type=str, default="sglang") + parser.add_argument("--num-prompts", type=int, default=16, + help="Number of calibration prompts to use (default: 16)") + parser.add_argument("--output-dir", type=str, default="calibration_output/") + parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + parser.add_argument( + "--watchdog-timeout", + type=float, + default=None, + metavar="SEC", + help="SGLang scheduler watchdog (seconds). Forward batches must complete within this time. " + "Default: engine default (300). Use 0 to disable when using this repo's SGLang fork.", + ) + args = parser.parse_args() + + # Classic HTTP downloads avoid XET chunk reconstruction ("Background writer channel + # closed") that often surfaces when the disk is full or nearly full. + if "HF_HUB_DISABLE_XET" not in os.environ: + os.environ["HF_HUB_DISABLE_XET"] = "1" + + if args.min_free_disk_gb > 0: + check_path = os.path.abspath(args.output_dir) + while check_path and not os.path.isdir(check_path): + parent = os.path.dirname(check_path) + if parent == check_path: + check_path = os.getcwd() + break + check_path = parent + usage = shutil.disk_usage(check_path) + free_gb = usage.free / (1024.0**3) + if free_gb < args.min_free_disk_gb: + raise SystemExit( + f"[calibrate] ERROR: Only {free_gb:.1f} GiB free on filesystem containing " + f"{args.output_dir!r} (checked from {check_path!r}). " + f"Need at least ~{args.min_free_disk_gb} GiB for Hugging Face weights, hub cache, " + f"and logs. Free disk space or point HF_HOME at a larger disk. 
" + f"To skip this check: --min-free-disk-gb 0" + ) + + # Lazy imports to avoid slow startup when just checking --help + import sglang as sgl + import torch + import vortex_torch + + os.makedirs(args.output_dir, exist_ok=True) + + print(f"[calibrate] Launching engine with hit-rate profiling enabled...") + engine_kwargs = dict( + model_path=args.model_name, + disable_cuda_graph=True, + page_size=args.page_size, + vortex_topk_val=args.topk_val, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + vortex_module_name=args.vortex_module_name, + vortex_max_seq_lens=12288, + mem_fraction_static=args.mem, + max_total_tokens=args.max_total_tokens, + kv_cache_dtype=args.kv_cache_dtype, + vortex_topk_type=args.topk_type, + vortex_topk_mapping_mode=0, # Use mode 0 during calibration + vortex_topk_histogram=True, # Enable histogram collection + ) + if args.watchdog_timeout is not None: + engine_kwargs["watchdog_timeout"] = args.watchdog_timeout + llm = sgl.Engine(**engine_kwargs) + + # Clear any residual histograms in the worker process + llm.clear_topk_histograms() + + # Load calibration prompts + prompts_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "examples", "amc23.jsonl" + ) + with open(prompts_path, "r", encoding="utf-8") as f: + all_requests = [json.loads(line) for line in f] + + # Use up to num_prompts + requests = all_requests[:args.num_prompts] + prompts = [req["prompt"] for req in requests] + + print(f"[calibrate] Running {len(prompts)} calibration prompts...") + sampling_params = { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_new_tokens": 8192, + } + llm.generate(prompts, sampling_params) + + # Collect histograms via RPC from worker process + histograms = llm.get_topk_histograms() + print(f"[calibrate] Collected {len(histograms)} histogram batches") + + if len(histograms) == 0: + 
print("[calibrate] ERROR: No histograms collected. " + "Ensure topk_type='sglang' and vortex_topk_histogram=True.", + file=sys.stderr) + llm.shutdown() + sys.exit(1) + + # Stack all histograms: each is [eff_bs, 256], concatenate along batch dim + all_hists = torch.cat(histograms, dim=0).numpy() # [total_samples, 256] + print(f"[calibrate] Total histogram samples: {all_hists.shape[0]}") + + # Regression guard: refuse to save a collapsed histogram. A healthy + # calibration touches tens to hundreds of bins; if almost everything lands + # in a single bin, the scoring pipeline silently produced zero scores + # (see the Sgl_Decode_Plan_Workload_Kernel `w > topk_val` bug fixed in + # csrc/utils_sglang.cu). Saving 20+ GB of all-zeros wastes disk and poisons + # downstream benches, so fail loudly here. + _pooled = all_hists.sum(axis=0).astype(np.float64) + _total = float(_pooled.sum()) + if _total > 0: + _top_frac = float(_pooled.max()) / _total + _nz_bins = int((_pooled > 0).sum()) + if _top_frac > 0.95 or _nz_bins < 5: + llm.shutdown() + raise SystemExit( + f"[calibrate] ERROR: degenerate histogram — top bin holds " + f"{_top_frac:.2%} of mass, only {_nz_bins}/256 bins nonzero. " + f"The scoring pipeline is likely not running (check " + f"winfo_num_workloads in plan_decode, or `w > topk_val` in " + f"Sgl_Decode_Plan_Workload_Kernel). Refusing to save to avoid " + f"writing a useless multi-GB file." 
+ ) + + # --- Generate LUT (mode 1) --- + # Aggregate histogram across all samples + avg_histogram = all_hists.mean(axis=0) + lut = compute_lut_from_histogram(avg_histogram) + lut_path = os.path.join(args.output_dir, "lut.npy") + np.save(lut_path, lut) + print(f"[calibrate] Saved LUT to {lut_path} (shape={lut.shape}, dtype={lut.dtype})") + + # --- Generate quantiles (mode 2) --- + # Use bin centers as proxy scores weighted by histogram counts + bin_centers = np.arange(256, dtype=np.float32) + # Expand histogram counts into a weighted score distribution + total_counts = avg_histogram.astype(np.float64) + total = total_counts.sum() + if total > 0: + cdf = np.cumsum(total_counts) / total + # Invert CDF to get quantile breakpoints in [0, 255] space + percentiles = np.linspace(0, 1, 256) + quantiles = np.interp(percentiles, cdf, bin_centers).astype(np.float32) + else: + quantiles = bin_centers.copy() + + quantiles_path = os.path.join(args.output_dir, "quantiles.npy") + np.save(quantiles_path, quantiles) + print(f"[calibrate] Saved quantiles to {quantiles_path} (shape={quantiles.shape}, dtype={quantiles.dtype})") + + # Save raw histograms for debugging + raw_path = os.path.join(args.output_dir, "raw_histograms.npy") + np.save(raw_path, all_hists) + print(f"[calibrate] Saved raw histograms to {raw_path} (shape={all_hists.shape})") + + # Cleanup + llm.clear_topk_histograms() + llm.shutdown() + print(f"[calibrate] Done. Output files in {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_topk_distribution.py b/benchmarks/profile_topk_distribution.py new file mode 100644 index 0000000..bea911b --- /dev/null +++ b/benchmarks/profile_topk_distribution.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +Profile TopK bin distribution and generate mapping tables. 
+
+This script collects Stage 1 (8-bit coarse histogram) distributions from
+the topk_sglang kernel and generates LUT/quantile mapping tables that
+can be used to equalize the bin distribution for improved sorting efficiency.
+
+Usage:
+    python benchmarks/profile_topk_distribution.py \
+        --model-name Qwen/Qwen3-1.7B \
+        --output mapping_tables.npz \
+        --num-prompts 32 \
+        --mem 0.7
+
+Output (.npz):
+    lut_tables: [num_collected, 256] uint8 - CDF-equalized LUT per sample
+    quantile_tables: [num_collected, 256] float32 - quantile breakpoints per sample
+    raw_histograms: [num_collected, 256] int32 - raw bin histograms
+"""
+
+import argparse
+import numpy as np
+import torch
+
+
+def compute_lut_from_histogram(histogram: np.ndarray) -> np.ndarray:
+    """Compute CDF-equalized LUT from a 256-bin histogram.
+
+    Args:
+        histogram: [256] int array of bin counts
+
+    Returns:
+        lut: [256] uint8 array where lut[i] = floor(CDF(i) * 255)
+    """
+    cdf = np.cumsum(histogram).astype(np.float64)
+    total = cdf[-1]
+    if total == 0:
+        return np.arange(256, dtype=np.uint8)
+    cdf_normalized = cdf / total
+    lut = np.floor(cdf_normalized * 255).astype(np.uint8)
+    return lut
+
+
+def compute_quantiles_from_scores(scores: np.ndarray, num_quantiles: int = 256) -> np.ndarray:
+    """Compute quantile breakpoints from raw float scores.
+
+    Args:
+        scores: 1D array of float scores
+        num_quantiles: number of quantile bins (default 256)
+
+    Returns:
+        quantiles: [num_quantiles] float32 array of sorted breakpoints
+    """
+    if len(scores) == 0:
+        return np.zeros(num_quantiles, dtype=np.float32)
+    percentiles = np.linspace(0, 100, num_quantiles)
+    quantiles = np.percentile(scores, percentiles).astype(np.float32)
+    return quantiles
+
+
+def generate_tables_from_histograms(histograms: np.ndarray) -> dict:
+    """Generate LUT and quantile tables from collected histograms. 
+ + Args: + histograms: [N, 256] int32 array of bin histograms + + Returns: + dict with 'lut_tables' and 'aggregate_lut' + """ + N = histograms.shape[0] + lut_tables = np.zeros((N, 256), dtype=np.uint8) + + for i in range(N): + lut_tables[i] = compute_lut_from_histogram(histograms[i]) + + # Aggregate: average histogram across all samples + avg_histogram = histograms.mean(axis=0) + aggregate_lut = compute_lut_from_histogram(avg_histogram) + + return { + 'lut_tables': lut_tables, + 'aggregate_lut': aggregate_lut, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Profile TopK bin distribution and generate mapping tables") + parser.add_argument("--output", type=str, default="mapping_tables.npz", + help="Output .npz file path") + parser.add_argument("--histograms-input", type=str, default=None, + help="Load pre-collected histograms from .npy file instead of running inference") + parser.add_argument("--scores-input", type=str, default=None, + help="Load pre-collected raw scores from .npy for quantile computation") + args = parser.parse_args() + + results = {} + + if args.histograms_input: + print(f"Loading histograms from {args.histograms_input}") + histograms = np.load(args.histograms_input) + if histograms.ndim == 1: + histograms = histograms.reshape(1, -1) + results['raw_histograms'] = histograms + + tables = generate_tables_from_histograms(histograms) + results.update(tables) + + if args.scores_input: + print(f"Loading scores from {args.scores_input}") + scores = np.load(args.scores_input) + quantiles = compute_quantiles_from_scores(scores.flatten()) + results['quantile_table'] = quantiles + + if not results: + print("No input provided. 
Use --histograms-input or --scores-input.") + print("\nTo collect histograms, use the topk_profile_histogram() function from vortex_torch_C:") + print(" from vortex_torch_C import topk_profile_histogram") + print(" histograms = torch.zeros(eff_batch_size, 256, dtype=torch.int32, device='cuda')") + print(" topk_profile_histogram(scores, dense_kv_indptr, histograms, eff_batch_size, bos, eos)") + print(" np.save('histograms.npy', histograms.cpu().numpy())") + return + + np.savez(args.output, **results) + print(f"Saved mapping tables to {args.output}") + for key, val in results.items(): + print(f" {key}: shape={val.shape}, dtype={val.dtype}") + + +if __name__ == "__main__": + main() diff --git a/csrc/archived/README.md b/csrc/archived/README.md new file mode 100644 index 0000000..6e08a1d --- /dev/null +++ b/csrc/archived/README.md @@ -0,0 +1,19 @@ +# Archived TopK kernels + +These files are **not compiled** (not listed in `setup.py`) and are kept only +for historical reference. + +- `topk_slgang_ori.cu` — the original SGLang TopK reference kernel (typo in + the filename is intentional, matches the upstream commit it was adapted + from). Superseded by the fused `fast_topk_vortex` path in + `../topk_sglang.cu`. +- `topk_sglang_ori_fastpath.cu` — the `fast_topk_ori` / + `TopKOutput_Ori_Kernel` / `launch_ori_kernel` code extracted out of + `../topk_sglang.cu`. It was the "zero mapping overhead" fast path with + flexible `radix_bits` (4–10). We no longer test it — mode 0 now goes + through the standard fused kernel with `MAPPING_NONE`, which pays no + mapping overhead because `mapped_convert_to_uint8` degenerates to + `convert_to_uint8` in that branch. + +If you need to resurrect any of this, add the `.cu` to `setup.py` and +re-export its entry points from `../register.cc` / `../register.h`. 
diff --git a/csrc/archived/fast_topk_vortex_prepass.cu b/csrc/archived/fast_topk_vortex_prepass.cu new file mode 100644 index 0000000..5b743f1 --- /dev/null +++ b/csrc/archived/fast_topk_vortex_prepass.cu @@ -0,0 +1,525 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// fast_topk_vortex — the heavy fused remap+topk kernel with auto-range, +// pivot, tail-window, topk-window pre-passes and LUT/quantile support. +// Extracted from csrc/topk_sglang.cu as part of the remap-benchmark refactor. +// Replaced by a lean fast_topk_clean_fused that applies a simple element-wise +// transform (from topk_mapping.cuh apply_transform) in Stage-1 bucketing — +// no pre-pass, no LUT, no auto-range. +// +// References types/constants from its former translation unit (TopKMappingParams, +// needs_*, mapped_convert_to_uint8, kSmem, kThreadsPerBlock, COUNTER_*). This +// file will not compile standalone; kept for history only. + +// ====================================================================== +// Templated version of fast_topk_cuda_tl with mapping support: +// - ScoreT: float or __nv_bfloat16 +// - StopAfterStage1: return after Stage 1 route/filter (for profiling) +// - WriteCounters: write diagnostic counters to global memory + +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping, + int* counters = nullptr) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t 
s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. + // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? 
mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. 
+ float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. 
+ constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... 
+ if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance heuristic. 
+ // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f;
+      }
+    }
+    __syncthreads();
+  } else {
+    if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; }
+    __syncthreads();
+  }
+
+  // Stage 1: 8-bit coarse histogram (with optional mapping)
+  // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*)
+  // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass.
+  // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route.
+  constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast<int>(sizeof(int));  // uint8 entries
+  uint8_t* bin_cache = reinterpret_cast<uint8_t*>(vh_input_idx[1]);
+  const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY);
+
+  if (tx < RADIX + 1) vh_histogram[tx] = 0;
+  __syncthreads();
+
+  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+    const auto bin = mapped_convert_to_uint8(
+        vortex_to_float(input[idx + row_start]),
+        mapping, s_mapping_lut, s_mapping_quantiles,
+        s_range_min, s_range_inv_range);
+    ::atomicAdd(&vh_histogram[bin], 1);
+    if (use_bin_cache) {
+      bin_cache[idx] = bin;
+    }
+  }
+  __syncthreads();
+
+  const auto run_cumsum = [&] {
+#pragma unroll 8
+    for (int i = 0; i < 8; ++i) {
+      static_assert(1 << 8 == RADIX);
+      if (C10_LIKELY(tx < RADIX)) {
+        const auto j = 1 << i;
+        const auto k = i & 1;
+        auto value = vh_histogram_buf[k][tx];
+        if (tx < RADIX - j) {
+          value += vh_histogram_buf[k][tx + j];
+        }
+        vh_histogram_buf[k ^ 1][tx] = value;
+      }
+      __syncthreads();
+    }
+  };
+
+  run_cumsum();
+  if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) {
+    vh_threshold_bin_id = tx;
+    vh_num_input[0] = 0;
+    vh_counter = 0;
+  }
+  __syncthreads();
+
+  const auto threshold_bin = vh_threshold_bin_id;
+  topk -= vh_histogram[threshold_bin + 1];
+
+  if (WriteCounters && tx == 0 && counters) {
+    counters[COUNTER_THRESHOLD_BIN] = threshold_bin;
+    counters[COUNTER_REMAINING_K] = topk;
+  }
+
+  if (topk == 0) {
+    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+      int bin;
+      if (use_bin_cache) {
+        bin =
static_cast<int>(bin_cache[idx]);
+      } else {
+        bin = static_cast<int>(
+            mapped_convert_to_uint8(
+                vortex_to_float(input[idx + row_start]),
+                mapping, s_mapping_lut, s_mapping_quantiles,
+                s_range_min, s_range_inv_range));
+      }
+      if (bin > threshold_bin) {
+        const auto pos = ::atomicAdd(&vh_counter, 1);
+        index[pos] = idx;
+      }
+    }
+    __syncthreads();
+    if (WriteCounters && tx == 0 && counters) {
+      counters[COUNTER_NUM_ABOVE] = vh_counter;
+      counters[COUNTER_NUM_EQUAL] = 0;
+      counters[COUNTER_REFINE_ROUNDS] = 0;
+      counters[COUNTER_STAGE2_INPUT] = 0;
+    }
+    return;
+  } else {
+    __syncthreads();
+    if (tx < RADIX + 1) vh_histogram[tx] = 0;
+    __syncthreads();
+
+    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+      const auto raw_input = vortex_to_float(input[idx + row_start]);
+      int bin;
+      if (use_bin_cache) {
+        bin = static_cast<int>(bin_cache[idx]);
+      } else {
+        bin = static_cast<int>(
+            mapped_convert_to_uint8(raw_input, mapping,
+                s_mapping_lut, s_mapping_quantiles,
+                s_range_min, s_range_inv_range));
+      }
+      if (bin > threshold_bin) {
+        const auto pos = ::atomicAdd(&vh_counter, 1);
+        index[pos] = idx;
+      } else if (bin == threshold_bin) {
+        const auto pos = ::atomicAdd(&vh_num_input[0], 1);
+        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
+          vh_input_idx[0][pos] = idx;
+          const auto b32 = convert_to_uint32(raw_input);
+          const auto sub_bin = (b32 >> 24) & 0xFF;
+          ::atomicAdd(&vh_histogram[sub_bin], 1);
+        }
+      }
+    }
+    __syncthreads();
+    if (WriteCounters && tx == 0 && counters) {
+      counters[COUNTER_NUM_ABOVE] = vh_counter;
+      counters[COUNTER_NUM_EQUAL] = vh_num_input[0];
+      counters[COUNTER_STAGE2_INPUT] = vh_num_input[0];
+    }
+    if (StopAfterStage1) return;
+  }
+
+  // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits)
+  if constexpr (WriteCounters) {
+    // Default: all 4 rounds used; overwritten at break if resolved early
+    if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4;
+  }
+#pragma unroll 4
+  for (int round = 0; round < 4; ++round) {
+    __shared__ int
vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } 
+}
+
+// Wrapper kernel: one CUDA block per batch*head segment
+template <typename ScoreT>
+__global__ __launch_bounds__(kThreadsPerBlock)
+void TopKOutput_Kernel(
+    const ScoreT* __restrict__ score,
+    const int* __restrict__ dense_kv_indptr,
+    const int* __restrict__ sparse_kv_indptr,
+    const int* __restrict__ dense_kv_indices,
+    int* __restrict__ sparse_kv_indices,
+    const int topk_val,
+    const int page_reserved_bos,
+    const int page_reserved_eos,
+    const TopKMappingParams mapping)
+{
+  const int bx = blockIdx.x;
+
+  const int start = dense_kv_indptr[bx] + page_reserved_bos;
+  const int end = dense_kv_indptr[bx + 1] - page_reserved_eos;
+  const int nblk = end - start;
+  if (nblk <= topk_val) return;
+
+  const ScoreT* __restrict__ score_blk = score + start;
+  const int* __restrict__ idx_blk = dense_kv_indices + start;
+  int* __restrict__ out_blk = sparse_kv_indices +
+      sparse_kv_indptr[bx] +
+      page_reserved_bos;
+
+  __shared__ int s_indices[VORTEX_MAX_TOPK];
+  fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping);
+  __syncthreads();
+
+  // Remap position indices -> page indices via dense_kv_indices
+  const int tx = threadIdx.x;
+  for (int i = tx; i < topk_val; i += kThreadsPerBlock) {
+    out_blk[i] = idx_blk[s_indices[i]];
+  }
+}
+
+
diff --git a/csrc/archived/topk_mapping_full.cuh b/csrc/archived/topk_mapping_full.cuh
new file mode 100644
index 0000000..f85204e
--- /dev/null
+++ b/csrc/archived/topk_mapping_full.cuh
@@ -0,0 +1,217 @@
+// Archived: not included by any compiled TU. See csrc/archived/README.md.
+// The full mapping header supporting LUT_CDF, QUANTILE, TRUNC8, SUBTRACT,
+// ADAPTIVE_TAIL_WINDOW, TOPK_WINDOW and the auto-range/pivot/tail-window
+// pre-pass infrastructure. Replaced by the lean element-wise-only header
+// at csrc/topk_mapping.cuh for the remap-benchmark refactor.
+#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort Stage-1 remapping strategies +// +// These transforms remap float scores before Stage 1's 8-bit +// histogram binning. The primary goal is to maximize coarse-bin +// resolution in the score region that determines the top-k +// cutoff, thereby: +// - shrinking the Stage-1 threshold bin (fewer collisions) +// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT +// - reducing the number of Stage-2 refine rounds +// +// Stage 2 refinement still uses convert_to_uint32() on raw +// floats, so final ordering correctness is always preserved. +// +// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly +// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) +// directly focuses all 256 bins on the competitive upper tail +// estimated from the top-k ratio, collapsing irrelevant +// low-score mass into bin 0. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // Original convert_to_uint8 behavior + MAPPING_LUT_CDF = 1, // LUT-based CDF equalization + MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping + MAPPING_POWER = 3, // Monotonic power transform + MAPPING_LOG = 4, // Log transform + // Mode 5 reserved (previously INDEX_CACHE, removed) + MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp + MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing + MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile + MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail + MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] +}; + +struct TopKMappingParams { + int mode; // 
TopKMappingMode + float power_exp; // For MAPPING_POWER (default 0.5) + // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion + // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). + const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr + const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr + bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) + int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) + int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW +}; + +// NOTE: convert_to_uint8() must be defined before including this header. +// It is defined in topk_sglang.cu within the anonymous namespace. + +// ---- Individual transform functions (return float, no bucketing) ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + +__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + +// ---- Transform dispatcher (returns float, no bucketing) ---- + +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, 
params.power_exp);
+    case MAPPING_LOG1P: return transform_log1p(x, params.power_exp);
+    case MAPPING_ERF: return transform_erf(x, params.power_exp);
+    case MAPPING_TANH: return transform_tanh(x, params.power_exp);
+    case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp);
+    default: return x;
+  }
+}
+
+// ---- Linear bucketing for transform modes ----
+
+__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) {
+  int bin = __float2int_rd((val - range_min) * inv_range);
+  return static_cast<uint8_t>(min(max(bin, 0), 255));
+}
+
+// ---- BF16-aware bucketing (mode 8) ----
+// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the
+// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical
+// data (the byte is almost entirely exponent). Instead, convert through
+// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the
+// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but
+// explicitly available as a named mode for documentation/benchmarking.
+
+__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) {
+  return convert_to_uint8(x);  // fp16 sign-flip bucketing
+}
+
+// ---- Non-transform mapping functions (unchanged) ----
+
+// LUT-based CDF equalization: lut[original_bin] -> equalized_bin
+__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) {
+  return s_lut[convert_to_uint8(x)];
+}
+
+// Quantile mapping: binary search over 256 sorted thresholds
+__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) {
+  // Binary search: find largest index i such that x >= s_quantiles[i]
+  // s_quantiles is sorted ascending, length 256
+  int lo = 0, hi = 255;
+#pragma unroll 8
+  for (int iter = 0; iter < 8; ++iter) {
+    int mid = (lo + hi + 1) >> 1;
+    if (x >= s_quantiles[mid]) {
+      lo = mid;
+    } else {
+      hi = mid - 1;
+    }
+  }
+  return static_cast<uint8_t>(lo);
+}
+
+// ---- Unified dispatcher ----
+// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass.
+ +__device__ __forceinline__ uint8_t mapped_convert_to_uint8( + float x, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles, + float range_min, + float inv_range) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + if (params.lut != nullptr) return map_lut_cdf(x, s_lut); + return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + case MAPPING_QUANTILE: + if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); + return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated + case MAPPING_POWER: + case MAPPING_LOG: + case MAPPING_ASINH: + case MAPPING_LOG1P: + case MAPPING_ERF: + case MAPPING_TANH: + case MAPPING_EXP_STRETCH: { + float val = apply_transform(x, params); + if (params.noscale) return convert_to_uint8(val); + return linear_map_to_uint8(val, range_min, inv_range); + } + case MAPPING_TRUNC8: + return convert_to_uint8_bf16(x); + case MAPPING_SUBTRACT: + return convert_to_uint8(x - range_min); // range_min repurposed as pivot + case MAPPING_ADAPTIVE_TAIL_WINDOW: + case MAPPING_TOPK_WINDOW: + return linear_map_to_uint8(x, range_min, inv_range); + default: // MAPPING_NONE + return convert_to_uint8(x); + } +} + +// Helper: check if a mapping mode needs the auto-range pre-pass +__device__ __forceinline__ bool needs_auto_range(int mode) { + return (mode == MAPPING_POWER || mode == MAPPING_LOG || + mode == MAPPING_ASINH || mode == MAPPING_LOG1P || + mode == MAPPING_ERF || mode == MAPPING_TANH || + mode == MAPPING_EXP_STRETCH); +} + +// Helper: check if a mapping mode needs the pivot pre-pass +__device__ __forceinline__ bool needs_pivot(int mode) { + return (mode == MAPPING_SUBTRACT); +} + +// Helper: check if mode is the adaptive tail-window pre-pass +__device__ __forceinline__ bool needs_tail_window(int mode) { + return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); +} + +// Helper: check if mode is the lightweight topk-window pre-pass +__device__ 
__forceinline__ bool needs_topk_window(int mode) {
+  return (mode == MAPPING_TOPK_WINDOW);
+}
diff --git a/csrc/archived/topk_sglang_ori_fastpath.cu b/csrc/archived/topk_sglang_ori_fastpath.cu
new file mode 100644
index 0000000..29970ec
--- /dev/null
+++ b/csrc/archived/topk_sglang_ori_fastpath.cu
@@ -0,0 +1,319 @@
+// Archived: not compiled. See csrc/archived/README.md
+//
+// Flexible-radix (RADIX_BITS 4..10) "ori fast path" for TopK. It was the
+// zero-mapping-overhead fast path used when mapping_mode == MAPPING_NONE.
+// No longer tested — mode 0 now routes through the fused TopKOutput_Kernel
+// with mapping.mode == MAPPING_NONE, which pays no extra cost because
+// mapped_convert_to_uint8 collapses to convert_to_uint8 in that branch.
+//
+// The code below was extracted verbatim from csrc/topk_sglang.cu as of the
+// fused-kernel refactor. It references helpers (kSmem, convert_to_uint32,
+// vortex_to_float, VORTEX_MAX_TOPK, kThreadsPerBlock, setup_kernel_smem_once,
+// CHECK_CUDA, topk_mapping.cuh types) from the surrounding translation unit.
+// Dropping this file into a build as-is will not compile; it is reference
+// only.
+
+template <int BITS>
+__device__ __forceinline__ uint16_t convert_to_uintN(float x) {
+  __half h = __float2half_rn(x);
+  uint16_t bits = __half_as_ushort(h);
+  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits) : static_cast<uint16_t>(bits | 0x8000);
+  return key >> (16 - BITS);
+}
+
+// ======================================================================
+// Ori fast path: zero-overhead topk with no mapping infrastructure.
+// Template on RADIX_BITS: 4-10 (16 to 1024 bins).
+// ====================================================================== +template +__device__ void fast_topk_ori( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 1 << RADIX_BITS; + constexpr auto RADIX_PAD = RADIX / 2; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); + static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: coarse histogram with RADIX bins + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + for (int i = 0; i < RADIX_BITS; ++i) { + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) + const auto run_cumsum_s2 = [&] { + for (int i = 0; i < 8; ++i) { + if (C10_LIKELY(tx < 256)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < 256 - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && 
s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uintN(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum_s2(); + if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Ori fast-path wrapper: zero mapping overhead, flexible radix +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Ori_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ 
dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch +template +void launch_ori_kernel( + const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, + const int* dense_kv_indices, int* sparse_kv_indices, + int topk_val, int reserved_bos, int reserved_eos, + int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) +{ + #define LAUNCH_ORI(BITS) \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Ori_Kernel<<>>( \ + score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ + topk_val, reserved_bos, reserved_eos) + switch (radix_bits) { + case 4: LAUNCH_ORI(4); break; + case 5: LAUNCH_ORI(5); break; + case 6: LAUNCH_ORI(6); break; + case 7: LAUNCH_ORI(7); break; + case 9: LAUNCH_ORI(9); break; + case 10: LAUNCH_ORI(10); break; + default: LAUNCH_ORI(8); break; + } + #undef LAUNCH_ORI +} + +// ====================================================================== +// Explicit ori baseline entry point — always uses the ori fast path +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, + const at::Tensor& 
dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t radix_bits) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_ori: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + launch_ori_kernel<__nv_bfloat16>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else if (x.scalar_type() == at::ScalarType::Float) { + launch_ori_kernel( + x.data_ptr(), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/archived/topk_slgang_ori.cu b/csrc/archived/topk_slgang_ori.cu new file mode 100644 index 0000000..04a2b73 --- /dev/null +++ b/csrc/archived/topk_slgang_ori.cu @@ -0,0 +1,546 @@ +/** + * @NOTE: This file is adapted from + * 
https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? 
src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + 
if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + 
const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); 
+ CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) 
{ + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} diff --git a/csrc/clean.py b/csrc/clean.py new file mode 100644 index 0000000..8d258bb --- /dev/null +++ b/csrc/clean.py @@ -0,0 +1,21 @@ +from pathlib import Path +import sys + +def clean_one_leading_space(path: str): + p = Path(path) + text = p.read_text(encoding="utf-8") + + cleaned = "".join( + line[1:] if line.startswith(" ") else line + for line in text.splitlines(keepends=True) + ) + + p.write_text(cleaned, encoding="utf-8") + print(f"Cleaned: {p}") + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python clean_indent.py ") + sys.exit(1) + + clean_one_leading_space(sys.argv[1]) \ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index fd9d4eb..8aa5aea 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,6 +8,54 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); 
m.def("topk_output", &topk_output); + m.def("topk_output_sglang", &topk_output_sglang, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); + m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("indices_out"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); + m.def("topk_output_sglang_fused", &topk_output_sglang_fused, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_remap_only", &topk_remap_only, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("remapped"), + py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode"), + py::arg("mapping_power")); + m.def("topk_profile_histogram", &topk_profile_histogram, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("histograms"), py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_profile_counters", &topk_profile_counters, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("counters"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), 
py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 92499ed..afdb97f 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -85,6 +85,88 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_seq_lengths +); + +void topk_output_sglang_ori( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& indices_out, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + +void topk_output_sglang_fused( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode, +const double mapping_power, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + +void topk_remap_only( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& remapped, +const int64_t eff_batch_size, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t mapping_mode, +const double mapping_power +); + +void topk_profile_histogram( +const at::Tensor& 
x, +const at::Tensor& dense_kv_indptr, +at::Tensor& histograms, +const int64_t eff_batch_size, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + +void topk_profile_counters( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& counters, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); void sglang_plan_decode_fa3( const at::Tensor& cached_seq_lens, diff --git a/csrc/topk.cu b/csrc/topk.cu index 62d747e..081bddf 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -117,8 +117,8 @@ const int page_reserved_eos) void topk_output( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, const at::Tensor& dense_kv_indices, +const at::Tensor& sparse_kv_indptr, at::Tensor& sparse_kv_indices, const int64_t eff_batch_size, const int64_t topk_val, @@ -196,8 +196,20 @@ const int64_t max_num_pages reserved_bos, reserved_eos ); + } else if (max_num_pages <= 8192){ + TopKOutput_BF16_Kernel<512, 16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); } else { - TORCH_CHECK(false); + TORCH_CHECK(false, "topk_output: max_num_pages=", max_num_pages, + " exceeds the supported template ladder (8192)."); } -} +} \ No newline at end of file diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh new file mode 100644 index 0000000..c645acb 
--- /dev/null +++ b/csrc/topk_mapping.cuh @@ -0,0 +1,232 @@ +#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort Stage-1 remap transforms (lean version). +// +// These are element-wise transforms applied to scores before +// the Stage-1 8-bit histogram bucketing. The goal is to spread +// a skewed raw distribution more uniformly across the 256 bins +// so the threshold bin shrinks and Stage-2 refinement does less +// work. Stage 2 still uses convert_to_uint32() on the remapped +// value's raw bits for tie-breaking. +// +// There is no pre-pass, no auto-range, no LUT, no quantile +// table, and no shared-memory state — each transform is a +// pure function of one float. The heavy pre-pass machinery +// (auto-range, pivot, tail-window, topk-window, LUT_CDF, +// QUANTILE, SUBTRACT, TRUNC8) lives in +// csrc/archived/fast_topk_vortex_prepass.cu. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // identity (no remap) + MAPPING_LUT_CDF = 1, // bin lookup: new_bin = lut[convert_to_uint8(x)] + MAPPING_QUANTILE = 2, // binary search over 256 calibrated quantile thresholds + MAPPING_POWER = 3, // sign(x) * |x|^p + MAPPING_LOG = 4, // sign(x) * log(|x| + 1) + MAPPING_ASINH = 6, // asinh(beta * x) + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|) + MAPPING_TRUNC8 = 8, // identity bucketing (historical name, alias of MAPPING_NONE) + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // x - pivot, with pivot = power_exp (free hyperparameter) + MAPPING_EXP_STRETCH = 13, // exp(alpha * x) + // Top-spreading transforms (see CLAUDE.md / remap bench plan): + // amplify differences in the high-score region so the top-K values + // occupy multiple Stage-1 bins instead of collapsing into one. 
+ MAPPING_SHIFT_POW2 = 15, // sign(x - p) * (x - p)^2 [p = power_exp] + MAPPING_SHIFT_POW3 = 16, // (x - p)^3 [p = power_exp] + MAPPING_LINEAR_STEEP = 17, // x + k * max(x, 0) [k = power_exp] + // One-sided spread: collapse below-pivot values into a single bin so + // every above-pivot page gets its own slice of the 256-bin histogram. + MAPPING_HALF_SQUARE = 18, // max(x - p, 0)^2 [p = power_exp] + MAPPING_HALF_CUBE = 19, // max(x - p, 0)^3 [p = power_exp] + // Bit-level remap: identity value transform, but the Stage-1 bucket + // function in fast_topk_clean_fused switches to a mantissa-heavy bit + // slice (bits [23:16] of convert_to_uint32) that gives 128 sub-bins + // per exponent slot instead of 4. Zero per-element compute overhead; + // the "remap" is the bucket change. Monotonic within 2 adjacent + // fp32 exponent slots. + MAPPING_DENSE_MANT = 20, // identity; bucketing handled in fused kernel +}; + +struct TopKMappingParams { + int mode; // TopKMappingMode + float power_exp; // Free hyperparameter: p / alpha / beta / pivot depending on mode + const uint8_t* __restrict__ lut; // [256] uint8 LUT, MAPPING_LUT_CDF only + const float* __restrict__ quantiles; // [256] float quantile breakpoints, MAPPING_QUANTILE only +}; + +// ---- Element-wise transforms ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + +__device__ __forceinline__ float 
transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + +// Signed squared distance from a pivot. ~3 ops (1 sub, 1 mul, 1 copysign). +// Quadratically amplifies differences between values far from pivot so the +// top-K region gets spread across multiple Stage-1 bins. +__device__ __forceinline__ float transform_shift_pow2(float x, float pivot) { + const float d = x - pivot; + return copysignf(d * d, d); +} + +// Signed cubic of distance from pivot. ~3 ops (1 sub, 2 mul; odd function so +// no copysign). Steeper growth than pow2 for even tighter top-K clusters. +__device__ __forceinline__ float transform_shift_pow3(float x, float pivot) { + const float d = x - pivot; + return d * d * d; +} + +// Half-range linear stretch: positive values get multiplied by (1 + k), +// negative values pass through untouched. ~2 ops (fmax + fma). For softmax- +// style attention scores (which are non-negative after softmax), k = 8..16 +// shifts the positive fp16 exponent up by 3..4 slots and empties out the +// collision at the top of the distribution. +__device__ __forceinline__ float transform_linear_steep(float x, float k) { + return fmaf(k, fmaxf(x, 0.0f), x); +} + +// One-sided shifted square: values below pivot collapse to 0 (they all end +// up in the same low Stage-1 bin), above-pivot values are squared so their +// differences amplify quadratically. ~2 ops (fmax + mul). The whole 256-bin +// histogram becomes dedicated to the top slice of the distribution. +__device__ __forceinline__ float transform_half_square(float x, float pivot) { + const float d = fmaxf(x - pivot, 0.0f); + return d * d; +} + +// One-sided shifted cube: like half_square but cubic. ~3 ops. Best when the +// top-K region is even more tightly clustered and needs steeper amplification. 
+__device__ __forceinline__ float transform_half_cube(float x, float pivot) { + const float d = fmaxf(x - pivot, 0.0f); + return d * d * d; +} + +// Compile-time templated dispatcher. When the caller knows the mapping mode +// at template-instantiation time, this lets the compiler fully inline the +// transform into the Stage-1 inner loop and eliminate the runtime switch +// that `apply_transform` would otherwise perform per element. Used by the +// per-mode specializations of `fast_topk_clean_fused` in topk_sglang.cu. +template +__device__ __forceinline__ float apply_transform_tmpl(float x, float p) { + if constexpr (MODE == MAPPING_POWER) return transform_power(x, p); + else if constexpr (MODE == MAPPING_LOG) return transform_log(x); + else if constexpr (MODE == MAPPING_ASINH) return transform_asinh(x, p); + else if constexpr (MODE == MAPPING_LOG1P) return transform_log1p(x, p); + else if constexpr (MODE == MAPPING_ERF) return transform_erf(x, p); + else if constexpr (MODE == MAPPING_TANH) return transform_tanh(x, p); + else if constexpr (MODE == MAPPING_SUBTRACT) return x - p; + else if constexpr (MODE == MAPPING_EXP_STRETCH) return transform_exp_stretch(x, p); + else if constexpr (MODE == MAPPING_SHIFT_POW2) return transform_shift_pow2(x, p); + else if constexpr (MODE == MAPPING_SHIFT_POW3) return transform_shift_pow3(x, p); + else if constexpr (MODE == MAPPING_LINEAR_STEEP) return transform_linear_steep(x, p); + else if constexpr (MODE == MAPPING_HALF_SQUARE) return transform_half_square(x, p); + else if constexpr (MODE == MAPPING_HALF_CUBE) return transform_half_cube(x, p); + else if constexpr (MODE == MAPPING_DENSE_MANT) return fmaxf(x, p); + else return x; // NONE / TRUNC8 +} + +// Pure element-wise dispatcher. Returns the *float value* after the transform. 
+// For bin-selection modes (LUT_CDF / QUANTILE) this is identity: the mapping +// happens in compute_stage1_bin() below instead of via a float transform, so +// Stage-2 tie-breaking uses the raw score bits for those modes. +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_SUBTRACT: return x - params.power_exp; + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + case MAPPING_SHIFT_POW2: return transform_shift_pow2(x, params.power_exp); + case MAPPING_SHIFT_POW3: return transform_shift_pow3(x, params.power_exp); + case MAPPING_LINEAR_STEEP: return transform_linear_steep(x, params.power_exp); + case MAPPING_HALF_SQUARE: return transform_half_square(x, params.power_exp); + case MAPPING_HALF_CUBE: return transform_half_cube(x, params.power_exp); + // MAPPING_DENSE_MANT clamps small/negative values to `power_exp` + // (default 0.5) so the subsequent dense bit bucket in the fused + // kernel sees a narrow 1–2 exponent window of positive values. + // Values at/below the clamp all hash to the lowest bin, which + // is always below the topk threshold in practice. + case MAPPING_DENSE_MANT: return fmaxf(x, params.power_exp); + case MAPPING_LUT_CDF: + case MAPPING_QUANTILE: + case MAPPING_TRUNC8: + default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE + } +} + +// Whether the mapping mode is a direct bin-selection function (LUT_CDF / +// QUANTILE). These modes need per-block shared-memory tables. 
+__device__ __forceinline__ bool mapping_uses_table(int mode) { + return mode == MAPPING_LUT_CDF || mode == MAPPING_QUANTILE; +} + +// Binary search over a sorted [256] quantile table. Returns the largest +// index i such that x >= quantiles[i], in [0, 255]. +__device__ __forceinline__ uint8_t quantile_bin_lookup( + float x, const float* __restrict__ s_quantiles) +{ + int lo = 0, hi = 255; +#pragma unroll 8 + for (int iter = 0; iter < 8; ++iter) { + int mid = (lo + hi + 1) >> 1; + if (x >= s_quantiles[mid]) lo = mid; + else hi = mid - 1; + } + return static_cast(lo); +} + +// Forward decl so compute_stage1_bin can call it. Defined in the enclosing TU. +__device__ __forceinline__ uint8_t convert_to_uint8(float x); + +// Compute the Stage-1 bin for a raw score under any mapping mode. LUT_CDF / +// QUANTILE use the shared-memory tables loaded at the kernel entry; every +// other mode falls back to convert_to_uint8(apply_transform(x)). +__device__ __forceinline__ uint8_t compute_stage1_bin( + float raw, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + return s_lut[convert_to_uint8(raw)]; + case MAPPING_QUANTILE: + return quantile_bin_lookup(raw, s_quantiles); + default: + return convert_to_uint8(apply_transform(raw, params)); + } +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu new file mode 100644 index 0000000..73366df --- /dev/null +++ b/csrc/topk_sglang.cu @@ -0,0 +1,1377 @@ +/** + * Vortex TopK kernels. + * + * Three production kernels: + * - fast_topk_clean : unmapped baseline (two-stage radix). + * - fast_topk_clean_fused : remap + topk fused (apply_transform + * applied inline in Stage-1 bucketing). + * - TopKRemapOnly_Kernel : standalone element-wise remap pass + * used by the split-phase benchmark. 
+ * + * Profiling kernels (counter collection, histogram collection) live in + * topk_sglang_profile.cu and MUST NOT be used for latency measurements — + * they intentionally write extra diagnostic state to global memory. + * + * Archived / historical kernels: csrc/archived/ (fast_topk_vortex with + * pre-pass modes, TopKOutput_Ori_Kernel with flexible radix_bits, the + * original SGLang reference kernel). + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +// Fused-kernel dynamic smem ceiling. The fused kernel uses `kSmem` bytes for +// f_input_idx (2 × SMEM_INPUT_SIZE ints) AND an extra `max_num_pages` bytes +// for s_bins (one uint8_t per page). Ceiling of 96 KB covers max_num_pages up +// to 65536 and fits the opt-in dynamic-smem limits on every target in +// setup.py (sm_86 ≥99KB, sm_89 100KB, sm_90 228KB, sm_100a/120 ≥100KB). +// Only `topk_output_sglang_fused` uses this ceiling; the other kernels keep +// kSmem as their dynamic-smem budget. 
+constexpr size_t kFusedSmemMax = 96 * 1024; + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +// Mantissa-heavy Stage-1 bucket for MAPPING_DENSE_MANT. Returns bits +// [23:16] of the sign-adjusted float32 key = 1 exp LSB + 7 top +// mantissa bits. 
This yields 128 mantissa sub-bins per exp slot (vs +// 4 in the current fp16 scheme — 32× more resolution) and is strictly +// monotonic across 2 adjacent fp32 exponent slots (factor-of-4 value +// range). Designed for the common case where the top-K scores cluster +// tightly: softmax-attention outputs on Qwen / Llama typically live +// in ~1 exp slot of magnitude near the top. Values with exponents +// outside the 2-slot monotonic window collide with lower bins, which +// only causes a correctness issue if top-K elements span more than +// 2 exp slots — verified empirically before shipping. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + +// ---- Vortex additions ---- + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; + +#include "topk_mapping.cuh" + + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit 
coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma 
unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + 
+__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* 
__restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* 
__restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = 
row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Templated clean baseline: identical algorithm to fast_topk_cuda_tl but +// parameterised on ScoreT (float or __nv_bfloat16) for the GQA / paged +// call paths that operate on bf16 attention scores. No mapping, no +// pre-pass — pure two-stage radix topk on fp16 bit-pattern bins. 
+// ====================================================================== +template +__device__ void fast_topk_clean( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + 
if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8-bit radix passes on raw fp32 bits +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - 
round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ====================================================================== +// Templated fused kernel: apply_transform(score) -> convert_to_uint8 +// is fused into Stage 1. Stage 2 still uses raw bits for tie-breaking +// (on the *remapped* value, not the original score) — this is a +// benchmarking kernel, the remapped Stage-2 ordering is acceptable. +// No pre-pass, no LUT, no shared-memory mapping state. +// ====================================================================== +template +__device__ void fast_topk_clean_fused( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams mapping) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int f_counter; + alignas(128) __shared__ int f_threshold_bin_id; + alignas(128) __shared__ int f_num_input[2]; + + // Per-element Stage-1 bin cache. Pass 1 of Stage 1 writes one byte per + // element; pass 2 reads it back so each element only pays a single + // apply_transform + global score read instead of two. 
+ // + // s_bins lives in DYNAMIC shared memory, placed immediately after the + // f_input_idx[2][SMEM_INPUT_SIZE] 2D array in the same extern __shared__ + // region. The host launch reserves `kSmem + max_num_pages` dynamic bytes + // (see `topk_output_sglang_fused`) so every block has `max_num_pages` + // bytes available past f_input_idx's 32 KB span. Per-block `length` + // (from dense_kv_indptr) is ≤ max_num_pages, so indexing stays in bounds. + // + // This layout keeps smem usage at kSmem + 4 KB for the existing + // pages_per_seg ≤ 4096 regimes (identical to the old 32 KB dynamic + + // 4 KB static) and only grows when the caller asks for a larger + // pages_per_seg — no occupancy regression on small configs. + + auto& f_histogram = f_histogram_buf[0]; + extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; + uint8_t* const s_bins = reinterpret_cast(&f_input_idx[2][0]); + + const int tx = threadIdx.x; + + // MODE is a compile-time template parameter, so every comparison below + // becomes a constant-folded `if constexpr` branch. The dense bucket + // path (MAPPING_DENSE_MANT) stays in the kernel but is completely + // elided when MODE != MAPPING_DENSE_MANT, and the value-space transform + // path stays in place for standard modes. LUT_CDF / QUANTILE are not + // supported by this templated kernel (they were dropped from the bench + // comparison earlier). + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + // Stage 1 pass 1: read each score from global, compute the Stage-1 + // bin via the compile-time-dispatched transform, cache it in s_bins so + // pass 2 can skip the second global read. With MODE known at compile + // time, apply_transform_tmpl inlines to just the chosen + // transform's instructions — no runtime switch overhead. 
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + int bin; + if constexpr (use_dense_bucket) { + bin = static_cast(convert_to_uint8_dense(remapped)); + } else { + bin = static_cast(convert_to_uint8(remapped)); + } + s_bins[idx] = static_cast(bin); + ::atomicAdd(&f_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = f_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += f_histogram_buf[k][tx + j]; + } + f_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[0] = 0; + f_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + // Shortcut: every page above threshold gets selected. Read the bin + // from the cache so we don't re-touch global memory or recompute + // apply_transform. + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + // Stage 1 pass 2: read the cached bin from SMEM. For elements + // outside the threshold bin we skip the global-memory load AND the + // apply_transform call entirely. Only the ~thr_size threshold-bin + // candidates re-read raw and re-apply the templated transform to + // compute the sub-bin needed for Stage-2 refinement. 
+ // + // Sub-bin shift selection (compile-time constant): + // - standard modes: Stage-1 used fp16 top-8-bit bucketing, so + // Stage-2 round 0 refines on uint32 bits [31:24] (the most + // significant bits not captured by the fp16 bucket). + // - MAPPING_DENSE_MANT: Stage-1 used bits [23:16], so the next + // useful discriminator is bits [15:8]. Skipping to offset 8 + // directly avoids two wasted Stage-2 rounds. + constexpr int sub_bin_offset_start = use_dense_bucket ? 8 : 24; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto pos = ::atomicAdd(&f_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine on raw bits of the remapped value. The per-round + // bit offset matches the sub_bin shift chosen above: standard modes + // start at offset 24 (bits [31:24]) and step down by 8 per round; + // MAPPING_DENSE_MANT starts at offset 8 (bits [15:8]) because Stage 1 + // already consumed bits [23:16] in the dense bucket. Both values are + // compile-time constants since MODE is a template parameter. + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int f_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = f_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[r_idx ^ 1] = 0; + f_last_remain = topk - f_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const auto offset = stage2_offset_start - round * 8; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + // Last refinement round: we have no more discriminator bits + // below the current offset, so emit any remaining elements as + // "tie-break fallback" via f_last_remain (ensures topk is met + // even when thr_size > sel_thr at the finest granularity). 
+ if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&f_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper kernels: one CUDA block per (batch*head) segment. + +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Clean_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Fused_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int 
page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Inverse of vortex_to_float: narrow a float back to ScoreT for the +// bf16-output remap path so the subsequent topk kernel can read half +// the bytes of a fp32 remapped buffer. +template +__device__ __forceinline__ T float_to_vortex(float x); +template <> +__device__ __forceinline__ float float_to_vortex(float x) { return x; } +template <> +__device__ __forceinline__ __nv_bfloat16 float_to_vortex<__nv_bfloat16>(float x) { + return __float2bfloat16(x); +} + +// Remap-only kernel: applies the element-wise transform to each score +// in the [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos) +// range and writes the result into an output tensor (OutT = float or +// bf16). Used by the split-phase benchmark (remap → unmapped topk). +// Writing bf16 halves memory bandwidth on the output and on the +// subsequent topk read; precision-wise it's lossless for the Stage-1 +// 8-bit bucket because fp16/bf16 both discard more mantissa than the +// bucket uses. 
+template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKRemapOnly_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + OutT* __restrict__ remapped, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= 0) return; + + const ScoreT* __restrict__ score_blk = score + start; + OutT* __restrict__ remap_blk = remapped + start; + + for (int i = tx; i < nblk; i += kThreadsPerBlock) { + const float y = apply_transform(vortex_to_float(score_blk[i]), mapping); + remap_blk[i] = float_to_vortex(y); + } +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + 
CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + 
CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Vortex host entry point — unmapped baseline topk (no remap). +// This is the "original topk kernel" used as the benchmarking baseline. 
// ======================================================================
// Unmapped baseline topk host entry: one CUDA block per (batch*head)
// segment. For each segment b it selects the topk_val highest scores in
// [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos)
// and scatters the corresponding dense_kv_indices entries into
// sparse_kv_indices at sparse_kv_indptr[b] + reserved_bos.
//
// NOTE(review): this span was recovered from a whitespace/angle-bracket
// mangled paste. The template arguments, `<<<...>>>` launch configurations
// and `data_ptr<T>()` calls below were reconstructed from the kernel
// parameter types and surviving neighboring tokens — confirm against the
// original file.
void topk_output_sglang(
    const at::Tensor& x,                 // scores (bf16 or fp32), CSR-flattened
    const at::Tensor& dense_kv_indptr,   // int32 [B+1] segment offsets into x
    const at::Tensor& sparse_kv_indptr,  // int32 [B+1] segment offsets into output
    const at::Tensor& dense_kv_indices,  // int32, same flattened layout as x
    at::Tensor& sparse_kv_indices,       // int32 output indices
    const int64_t eff_batch_size,        // number of segments == grid size
    const int64_t topk_val,              // k per segment; must be <= VORTEX_MAX_TOPK
    const int64_t reserved_bos,          // leading pages excluded from selection
    const int64_t reserved_eos,          // trailing pages excluded from selection
    const int64_t max_num_pages)         // unused here; kept for interface parity with the fused variant
{
    TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK,
        "topk_output: topk_val (", topk_val,
        ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")");

    CHECK_CUDA(x);
    CHECK_CUDA(dense_kv_indptr);
    CHECK_CUDA(sparse_kv_indptr);
    CHECK_CUDA(dense_kv_indices);
    CHECK_CUDA(sparse_kv_indices);
    (void)max_num_pages;  // only the fused entry sizes dynamic smem off this

    dim3 nblks(eff_batch_size);
    dim3 nthreads(kThreadsPerBlock);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    if (x.scalar_type() == at::ScalarType::BFloat16) {
        // One-time opt-in of kSmem bytes of dynamic shared memory for this
        // kernel specialization (cached inside setup_kernel_smem_once).
        setup_kernel_smem_once<TopKOutput_Clean_Kernel<__nv_bfloat16>, kSmem>();
        TopKOutput_Clean_Kernel<__nv_bfloat16><<<nblks, nthreads, kSmem, stream>>>(
            reinterpret_cast<__nv_bfloat16*>(x.data_ptr()),
            dense_kv_indptr.data_ptr<int>(),
            sparse_kv_indptr.data_ptr<int>(),
            dense_kv_indices.data_ptr<int>(),
            sparse_kv_indices.data_ptr<int>(),
            topk_val, reserved_bos, reserved_eos);
    } else if (x.scalar_type() == at::ScalarType::Float) {
        setup_kernel_smem_once<TopKOutput_Clean_Kernel<float>, kSmem>();
        TopKOutput_Clean_Kernel<float><<<nblks, nthreads, kSmem, stream>>>(
            x.data_ptr<float>(),
            dense_kv_indptr.data_ptr<int>(),
            sparse_kv_indptr.data_ptr<int>(),
            dense_kv_indices.data_ptr<int>(),
            sparse_kv_indices.data_ptr<int>(),
            topk_val, reserved_bos, reserved_eos);
    } else {
        TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type());
    }

    const auto result = cudaGetLastError();
    TORCH_CHECK(result == cudaSuccess,
        "topk_output kernel failed: ", ::cudaGetErrorString(result));
}

// ======================================================================
// Fused remap + topk host entry. Applies apply_transform(score, mapping)
// inline inside the Stage-1 histogram build — single kernel launch,
// single pass over the score tensor.
// ======================================================================
// Fused remap + topk: same CSR-segment selection as topk_output_sglang,
// but the element-wise mapping is applied inline inside the kernel so the
// score tensor is read exactly once. Each (dtype, mapping mode) pair is a
// distinct kernel specialization so the transform is fully inlined.
//
// NOTE(review): recovered from a mangled paste — the `<<<...>>>` launch
// configs, `setup_kernel_smem_once<...>` / `static_cast<...>` arguments
// and the MODE-as-template-argument dispatch were reconstructed from the
// file's own comments and surviving tokens; confirm against the original.
void topk_output_sglang_fused(
    const at::Tensor& x,
    const at::Tensor& dense_kv_indptr,
    const at::Tensor& sparse_kv_indptr,
    const at::Tensor& dense_kv_indices,
    at::Tensor& sparse_kv_indices,
    const int64_t eff_batch_size,
    const int64_t topk_val,
    const int64_t reserved_bos,
    const int64_t reserved_eos,
    const int64_t max_num_pages,        // sizes the per-launch s_bins smem region
    const int64_t mapping_mode,         // one of the MAPPING_* enumerators
    const double mapping_power,         // exponent for MAPPING_POWER
    std::optional<at::Tensor> mapping_lut,        // retained for API compat; ignored
    std::optional<at::Tensor> mapping_quantiles)  // retained for API compat; ignored
{
    TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK,
        "topk_output_sglang_fused: topk_val (", topk_val,
        ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")");

    // Dynamic-smem layout for the fused kernel:
    //   [ f_input_idx (2 × SMEM_INPUT_SIZE × sizeof(int) = kSmem bytes)
    //     s_bins      (bins_bytes = align_up(max_num_pages, 16)) ]
    // The per-launch request is the sum of both and must fit under
    // kFusedSmemMax, which setup_kernel_smem_once opted this kernel into
    // via cudaFuncSetAttribute(MaxDynamicSharedMemorySize, ...).
    const size_t bins_bytes =
        (static_cast<size_t>(max_num_pages) + size_t(15)) & ~size_t(15);
    const size_t smem_bytes = kSmem + bins_bytes;
    TORCH_CHECK(smem_bytes <= kFusedSmemMax,
        "topk_output_sglang_fused: max_num_pages (", max_num_pages,
        ") exceeds the fused kernel's dynamic smem ceiling. "
        "Requested smem=", smem_bytes, " bytes, ceiling=", kFusedSmemMax,
        " bytes. Raise kFusedSmemMax (and verify GPU opt-in limits) or "
        "reduce pages_per_seg.");

    // The `mapping_lut` / `mapping_quantiles` optionals are kept in the
    // pybind signature for backward compatibility but ignored: the
    // templated fused kernel drops the LUT_CDF / QUANTILE code paths.
    (void)mapping_lut;
    (void)mapping_quantiles;

    CHECK_CUDA(x);
    CHECK_CUDA(dense_kv_indptr);
    CHECK_CUDA(sparse_kv_indptr);
    CHECK_CUDA(dense_kv_indices);
    CHECK_CUDA(sparse_kv_indices);

    TopKMappingParams mapping{};
    // decltype keeps the casts correct even though the original cast
    // target types were lost in transit.
    mapping.mode = static_cast<decltype(mapping.mode)>(mapping_mode);
    mapping.power_exp = static_cast<decltype(mapping.power_exp)>(mapping_power);
    mapping.lut = nullptr;
    mapping.quantiles = nullptr;

    dim3 nblks(eff_batch_size);
    dim3 nthreads(kThreadsPerBlock);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    // Each mapping mode compiles to its own kernel specialization so the
    // transform is fully inlined (no runtime switch in the inner loop).
    // The outer dispatch is a one-time per-call cost, negligible relative
    // to the kernel runtime.
    #define VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MODE_VAL)                  \
        do {                                                                  \
            setup_kernel_smem_once<TopKOutput_Fused_Kernel<DTYPE, MODE_VAL>,  \
                                   kFusedSmemMax>();                          \
            TopKOutput_Fused_Kernel<DTYPE, MODE_VAL>                          \
                <<<nblks, nthreads, smem_bytes, stream>>>(                    \
                    PTR_EXPR,                                                 \
                    dense_kv_indptr.data_ptr<int>(),                          \
                    sparse_kv_indptr.data_ptr<int>(),                         \
                    dense_kv_indices.data_ptr<int>(),                         \
                    sparse_kv_indices.data_ptr<int>(),                        \
                    topk_val, reserved_bos, reserved_eos, mapping);           \
        } while (0)

    #define VORTEX_DISPATCH_MODE(DTYPE, PTR_EXPR)                                                     \
        do {                                                                                          \
            switch (mapping.mode) {                                                                   \
            case MAPPING_NONE:        VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_NONE); break;    \
            case MAPPING_POWER:       VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_POWER); break;   \
            case MAPPING_LOG:         VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG); break;     \
            case MAPPING_ASINH:       VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ASINH); break;   \
            case MAPPING_LOG1P:       VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG1P); break;   \
            case MAPPING_TRUNC8:      VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break;  \
            case MAPPING_ERF:         VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ERF); break;     \
            case MAPPING_TANH:        VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TANH); break;    \
            case MAPPING_SUBTRACT:    VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break;\
            case MAPPING_EXP_STRETCH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \
            case MAPPING_SHIFT_POW2:  VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break;  \
            case MAPPING_SHIFT_POW3:  VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break;  \
            case MAPPING_LINEAR_STEEP:VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break;\
            case MAPPING_HALF_SQUARE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \
            case MAPPING_HALF_CUBE:   VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break;   \
            case MAPPING_DENSE_MANT:  VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break;  \
            default:                                                                                  \
                TORCH_CHECK(false, "topk_output_sglang_fused: unsupported mapping_mode ", mapping.mode); \
            }                                                                                         \
        } while (0)

    if (x.scalar_type() == at::ScalarType::BFloat16) {
        VORTEX_DISPATCH_MODE(__nv_bfloat16, reinterpret_cast<__nv_bfloat16*>(x.data_ptr()));
    } else if (x.scalar_type() == at::ScalarType::Float) {
        VORTEX_DISPATCH_MODE(float, x.data_ptr<float>());
    } else {
        TORCH_CHECK(false, "topk_output_sglang_fused: unsupported dtype ", x.scalar_type());
    }

    #undef VORTEX_DISPATCH_MODE
    #undef VORTEX_DISPATCH_FUSED

    const auto result = cudaGetLastError();
    TORCH_CHECK(result == cudaSuccess,
        "topk_output_sglang_fused kernel failed: ", ::cudaGetErrorString(result));
}

// ======================================================================
// Standalone remap kernel. Writes apply_transform(score) into a
// float32 output buffer without running topk. Used by the split-phase
// benchmark (remap → unmapped topk) to measure each phase independently.
+// ====================================================================== +void topk_remap_only( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& remapped, // float32 or bfloat16, same numel as x + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(remapped); + TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float + || remapped.scalar_type() == at::ScalarType::BFloat16, + "remapped output must be float32 or bfloat16"); + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + // Four-way dispatch on (input dtype, output dtype). bf16→bf16 is the + // new "batch pre-transform" path that halves memory bandwidth vs the + // fp32 output: the remap writes half the bytes and the subsequent + // topk_output_sglang reads half the bytes. Precision is preserved + // because Stage-1 bucketing only uses the top 8 bits of an fp16 key + // which both fp32 and bf16 capture. 
+ #define VORTEX_DISPATCH_REMAP(IN_CPP, OUT_CPP, IN_PTR_EXPR, OUT_PTR_EXPR) \ + TopKRemapOnly_Kernel<<>>( \ + IN_PTR_EXPR, dense_kv_indptr.data_ptr(), OUT_PTR_EXPR, \ + reserved_bos, reserved_eos, mapping) + + const bool in_bf16 = (x.scalar_type() == at::ScalarType::BFloat16); + const bool in_fp32 = (x.scalar_type() == at::ScalarType::Float); + const bool out_bf16 = (remapped.scalar_type() == at::ScalarType::BFloat16); + + if (in_bf16 && out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, __nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_bf16 && !out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, float, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + remapped.data_ptr()); + } else if (in_fp32 && out_bf16) { + VORTEX_DISPATCH_REMAP(float, __nv_bfloat16, + x.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_fp32 && !out_bf16) { + VORTEX_DISPATCH_REMAP(float, float, + x.data_ptr(), + remapped.data_ptr()); + } else { + TORCH_CHECK(false, "topk_remap_only: unsupported dtype ", x.scalar_type()); + } + + #undef VORTEX_DISPATCH_REMAP + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_remap_only kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/topk_sglang_ori.cu b/csrc/topk_sglang_ori.cu new file mode 100644 index 0000000..55a99b2 --- /dev/null +++ b/csrc/topk_sglang_ori.cu @@ -0,0 +1,619 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. 
fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + // NOTE: TopK is a compile-time constant here because shared-memory + // allocations inside the transform kernels depend on it. We drop it to + // 30 to match the vortex benchmark's --topk-val 30 configuration. The + // transform kernels (decode/prefill/prefill_ragged) still carry a manual + // unroll that assumes TopK==2048; that code path is unreachable from the + // bench (we only invoke fast_topk_interface), so the corresponding + // static_asserts have been removed below. + constexpr int TopK = 30; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? 
i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } 
else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) 
+ const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from 
bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) 
+ const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). 
+ return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + } // namespace + + // The public interface functions below collide by name with identically + // named symbols in topk_sglang.cu. Wrap them in `sglang_ori` so both + // translation units can be linked into the same vortex_torch_C extension. + namespace sglang_ori { + + #ifndef CHECK_CUDA + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + #endif + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && 
src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && 
topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + } // namespace sglang_ori + +// ====================================================================== +// Thin vortex_torch_C adapter: accepts the same CSR-ish inputs as +// topk_output_sglang so bench_topk.py can treat the original SGLang kernel +// as an alternate baseline. The ori kernel has TopK baked in as a compile- +// time constant; this build sets it to 30 to match --topk-val 30. 
+// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, // [total_dense, 1, 1] or [total_dense], bf16/fp32 + const at::Tensor& dense_kv_indptr, // int32 [eff_bs + 1] (unused — synthetic bench rows are uniform) + at::Tensor& indices_out, // int32 [eff_bs, TopK] + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(dense_kv_indptr.is_cuda(), "dense_kv_indptr must be a CUDA tensor"); + TORCH_CHECK(indices_out.is_cuda(), "indices_out must be a CUDA tensor"); + TORCH_CHECK(indices_out.scalar_type() == at::ScalarType::Int, + "indices_out must be int32"); + TORCH_CHECK(topk_val == static_cast(30), + "topk_output_sglang_ori: this build of the ori kernel hard-codes TopK=30; " + "rebuild topk_sglang_ori.cu with a different TopK if you need another value. " + "Got topk_val=", topk_val); + TORCH_CHECK(indices_out.dim() == 2 + && indices_out.size(0) == eff_batch_size + && indices_out.size(1) == 30, + "indices_out must be [eff_batch_size, 30]"); + + // ori kernel requires fp32 [B, stride] scores. Caller typically passes + // the bf16 score tensor; we materialize an fp32 view once per call. 
+ at::Tensor score_f32; + if (x.scalar_type() == at::ScalarType::Float) { + score_f32 = x.contiguous().view({eff_batch_size, max_num_pages}); + } else if (x.scalar_type() == at::ScalarType::BFloat16) { + score_f32 = x.to(at::kFloat).contiguous().view({eff_batch_size, max_num_pages}); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + auto opts_i32 = at::TensorOptions().dtype(at::kInt).device(x.device()); + const int32_t usable_len = + static_cast(max_num_pages - reserved_bos - reserved_eos); + at::Tensor lengths = at::full({eff_batch_size}, usable_len, opts_i32); + at::Tensor row_starts = at::full({eff_batch_size}, + static_cast(reserved_bos), opts_i32); + + sglang_ori::fast_topk_interface( + score_f32, indices_out, lengths, + std::optional(row_starts)); +} \ No newline at end of file diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu new file mode 100644 index 0000000..7fe9981 --- /dev/null +++ b/csrc/topk_sglang_profile.cu @@ -0,0 +1,620 @@ +/** + * TopK profiling kernels: histogram collection, stage-1-only timing, + * and diagnostic counter collection. + * + * Separated from topk_sglang.cu to reduce template instantiation + * pressure on CUDA shared memory resources. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. 
+// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); +} + +// Mirror of convert_to_uint8_dense in topk_sglang.cu so that the +// profile kernel (topk_profile_histogram / topk_profile_counters) +// reports accurate thr_bin / thr_size / abv_bins / pg/bin for +// MAPPING_DENSE_MANT. Keep in sync with the production kernel. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + + +constexpr int VORTEX_MAX_TOPK = 2048; + +// Diagnostic counters written by the profiling kernel. These kernels are +// NOT used for latency measurements — they intentionally add global-memory +// writes that distort timings. Latency is measured against the clean +// production kernels in topk_sglang.cu. 
+constexpr int COUNTER_THRESHOLD_BIN = 0; +constexpr int COUNTER_NUM_ABOVE = 1; +constexpr int COUNTER_NUM_EQUAL = 2; +constexpr int COUNTER_REMAINING_K = 3; +constexpr int COUNTER_REFINE_ROUNDS = 4; +constexpr int COUNTER_STAGE2_INPUT = 5; +constexpr int NUM_TOPK_COUNTERS = 6; + +#include "topk_mapping.cuh" + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling variant of fast_topk_clean_fused that writes diagnostic +// counters at the end of Stage 1 and at each Stage 2 early-exit. +// Shape / semantics identical to the production kernel, with one extra +// global-memory write pass at the end of each stage. Do not use for +// latency measurements. 
+// ====================================================================== +template +__device__ void fast_topk_profile( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams mapping, + int* __restrict__ counters) // [NUM_TOPK_COUNTERS] +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int p_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int p_counter; + alignas(128) __shared__ int p_threshold_bin_id; + alignas(128) __shared__ int p_num_input[2]; + + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + auto& p_histogram = p_histogram_buf[0]; + extern __shared__ int p_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Mirror of the production kernel: MAPPING_DENSE_MANT bypasses + // apply_transform and uses a mantissa-heavy fp32 bit slice for the + // Stage-1 bucket. 
+ const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); + + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); // fmaxf(x, pivot) + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } + ::atomicAdd(&p_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = p_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += p_histogram_buf[k][tx + j]; + } + p_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[0] = 0; + p_counter = 0; + } + __syncthreads(); + + const int threshold_bin_0 = p_threshold_bin_id; + const int threshold_bin_size = p_histogram[threshold_bin_0]; // pre-reset count + topk -= p_histogram[threshold_bin_0 + 1]; + + if (tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin_0; + counters[COUNTER_NUM_EQUAL] = threshold_bin_size; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + int bin; + if (use_dense_bucket) { + const float 
clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + + const int sub_bin_offset_start = use_dense_bucket ? 8 : 24; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = use_dense_bucket + ? static_cast(convert_to_uint8_dense(remapped)) + : static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin_0) { + const auto pos = ::atomicAdd(&p_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_STAGE2_INPUT] = p_num_input[0]; + } + } + + // Stage 2 refinement. Standard modes run up to 4 rounds (offsets + // 24/16/8/0); MAPPING_DENSE_MANT runs up to 2 rounds (offsets 8/0) + // because Stage 1 already consumed bits [23:16] of the fp32 key. + const int stage2_offset_start = use_dense_bucket ? 8 : 24; + const int stage2_max_rounds = use_dense_bucket ? 
2 : 4; + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = stage2_max_rounds; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int p_last_remain; + const auto r_idx = round % 2; + const auto _raw_num_input = p_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[r_idx ^ 1] = 0; + p_last_remain = topk - p_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = p_threshold_bin_id; + topk -= p_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = round + 1; + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&p_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + 
const auto pos = ::atomicAdd(&p_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper: one block per (batch*head) segment. Writes counters per +// segment into a [eff_batch_size, NUM_TOPK_COUNTERS] int32 tensor. +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKProfileCounters_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_profile( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Histogram-only profiling kernel: builds a 256-bin histogram of the +// remapped bins for each segment. Purely diagnostic — never timed. 
+template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKProfileHistogram_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + int* __restrict__ histograms, // [eff_batch_size, 256] + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + + const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); + if (nblk > 0) { + const ScoreT* __restrict__ score_blk = score + start; + for (int i = tx; i < nblk; i += BLOCK_SIZE) { + const float raw = vortex_to_float(score_blk[i]); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } + ::atomicAdd(&s_histogram[bin], 1); + } + } + __syncthreads(); + + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) out[tx] = s_histogram[tx]; +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +static TopKMappingParams build_mapping_params( + int64_t mapping_mode, double mapping_power, + std::optional& 
mapping_lut, + std::optional& mapping_quantiles) +{ + TopKMappingParams m{}; + m.mode = static_cast(mapping_mode); + m.power_exp = static_cast(mapping_power); + m.lut = nullptr; + m.quantiles = nullptr; + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + TORCH_CHECK(lut.is_cuda(), "mapping_lut must be a CUDA tensor"); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + m.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + TORCH_CHECK(q.is_cuda(), "mapping_quantiles must be a CUDA tensor"); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + m.quantiles = q.data_ptr(); + } + return m; +} + +// ====================================================================== +// Profiling: per-segment 256-bin histograms of Stage 1 remapped bins. 
+// ====================================================================== +void topk_profile_histogram( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& histograms, + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKProfileHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKProfileHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_histogram: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: full pipeline + per-segment diagnostic counters. +// Adds extra global-memory writes — never use for latency measurement. 
+// ====================================================================== +void topk_profile_counters( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& counters, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + 
dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_counters: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/utils_sglang.cu b/csrc/utils_sglang.cu index 1420e9e..a7ddf42 100644 --- a/csrc/utils_sglang.cu +++ b/csrc/utils_sglang.cu @@ -82,16 +82,20 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // See note in Sgl_Decode_Plan_Workload_Kernel: we used to skip slots + // where w ≤ topk_val, but downstream (GeMV / topK / histogram) has no + // matching skip, so it read uninitialised scores and silently + // produced all-zero results. Emit workloads for every slot with w > 0. + page_count[i] = (w > 0) ? w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workload = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; @@ -218,16 +222,22 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? 
+ (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // Previously: (w > topk_val) ? w : 0, which skipped scoring on slots + // where the dense page count is already ≤ topk_val. Downstream (GeMV, + // topK, histogram profiling) does NOT have a matching skip, so it + // would read uninitialised scores and silently return garbage (all + // zero). Emit workloads for every slot with w > 0 so scoring always + // runs; when w ≤ topk_val the topK degenerates to "select all w". + page_count[i] = (w > 0) ? w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workloads = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..d14650f --- /dev/null +++ b/examples/README.md @@ -0,0 +1,399 @@ +# Vortex Torch Examples + +End-to-end accuracy evaluation and profiling pipelines for Vortex sparse attention on top of the SGLang inference engine. The scripts in this directory evaluate different TopK kernel variants, mapping functions, KV-cache quantization settings, and external sparse-attention backends on math reasoning benchmarks. + +--- + +## Mapping Functions Reference + +The TopK Stage-1 radix histogram uses 256 uint8 bins. A **mapping function** transforms raw attention scores before binning to improve bucket uniformity and reduce tail latency. Set via `--topk-mapping-mode`. 
+ +| Mode | Name | Formula | Requires Calibration | Hyperparameter (`--topk-mapping-power`) | +|------|------|---------|---------------------|-----------------------------------------| +| 0 | None | FP16 bit-pattern bucketing | No | — | +| 1 | LUT CDF | `lut[original_bin]` (CDF equalization) | Yes (`--topk-mapping-lut-path`) | — | +| 2 | Quantile | Binary search over 256 float thresholds | Yes (`--topk-mapping-quantiles-path`) | — | +| 3 | Power | `sign(x) * \|x\|^p` | No | `p` (exponent, default 0.5) | +| 4 | Log | `sign(x) * log(\|x\| + 1)` | No | — | +| 5 | Index Cache | Reuse top-k indices from a preceding layer | No | — (see `--index-cache-shared-layers`) | +| 6 | Asinh | `asinh(beta * x)` | No | `beta` (default 0.5) | +| 7 | Log1p | `sign(x) * log1p(alpha * \|x\|)` | No | `alpha` (default 0.5) | +| 8 | Trunc8 | BF16 upper-8-bit bucketing | No | — | + +Modes 1 and 2 require an offline calibration step (see `calibrate_topk.py` in `benchmarks/`). Modes 3, 6, and 7 accept a tunable hyperparameter via `--topk-mapping-power`. + +--- + +## Python Scripts + +### `verify_algo.py` — End-to-End Accuracy Benchmark + +The primary evaluation script. Loads AMC 2023 math problems from `amc23.jsonl`, runs inference via the SGLang engine with Vortex sparse attention, and scores answers using `lighteval`'s extractive-match metric. Reports `mean@N`, `pass@N`, throughput, and memory access cost. 
+ +**Usage:** + +```bash +python verify_algo.py [OPTIONS] +``` + +**CLI Arguments:** + +| Argument | Default | Description | +|----------|---------|-------------| +| `--trials` | 2 | Number of trials (each prompt repeated N times) | +| `--topk-val` | 30 | Number of top-k pages to select per segment | +| `--page-size` | 16 | Tokens per KV-cache page | +| `--vortex-module-name` | `gqa_block_sparse_attention` | Sparse attention algorithm module | +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model identifier | +| `-f`, `--full-attention` | off | Disable sparse attention (full-attention baseline) | +| `--mem` | 0.8 | Static GPU memory fraction for SGLang | +| `--kv-cache-dtype` | `auto` | KV cache dtype: `auto`, `fp8_e5m2`, `fp8_e4m3`, `int8` | +| `--topk-type` | `naive` | TopK kernel: `naive` (CUB radix sort) or `sglang` (fast two-stage radix) | +| `--topk-mapping-mode` | 0 | Mapping function for Stage-1 binning (see table above) | +| `--topk-mapping-power` | 0.5 | Hyperparameter for modes 3/6/7 | +| `--topk-mapping-lut-path` | None | `.npy` uint8[256] LUT for mode 1 | +| `--topk-mapping-quantiles-path` | None | `.npy` float32[256] quantiles for mode 2 | +| `--index-cache-shared-layers` | None | Layer IDs that skip the indexer and reuse a previous layer's indices | + +**Fixed engine settings:** `attention_backend=flashinfer`, `vortex_max_seq_lens=12288`, layer 0 skipped, `reserved_bos=1`, `reserved_eos=2`. Sampling: `temperature=0.6`, `top_p=0.95`, `top_k=20`, `max_new_tokens=8192`. + +**Index cache note (mode 5):** When `--topk-mapping-mode 5` is set without `--index-cache-shared-layers`, the script defaults to even layers `[2, 4, 6, ..., 26]` and internally resets the mapping mode to 0 while passing the shared-layer list to the engine. 
+ +**Example — full-attention baseline:** + +```bash +python verify_algo.py --full-attention --trials 8 --mem 0.7 +``` + +**Example — sglang TopK with power mapping:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power 0.25 \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +**Example — sglang TopK with calibrated LUT:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path calibration/lut.npy \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +--- + +### `verify_aim24.py` — AIME 2024 Throughput Test (Legacy) + +A standalone throughput script that loads AIME 2024 from HuggingFace (`HuggingFaceH4/aime_2024`), builds chat prompts using the Qwen3 tokenizer with `enable_thinking=True`, and repeats each prompt 8 times. Outputs a JSONL file with generation results and timing metadata. Does **not** compute accuracy metrics. + +**Usage:** + +```bash +python verify_aim24.py +``` + +All settings are hard-coded (no CLI arguments): + +| Setting | Value | +|---------|-------| +| Model | `Qwen/Qwen3-0.6B` | +| Page size | 16 | +| Selected pages | 29 | +| Max sequence length | 20480 | +| Module | `block_sparse_attention` | +| Memory fraction | 0.9 | +| Max new tokens | 16384 | +| CUDA graph | Enabled | + +--- + +## Shell Scripts + +All shell scripts set `CUDA_VISIBLE_DEVICES` and save timestamped logs to `results/`. + +### `verify_algo.sh` — Baseline TopK Comparison (Naive vs SGLang) + +Runs `verify_algo.py` with `block_sparse_attention` comparing the `naive` and `sglang` TopK kernels. Each configuration is repeated `REPEAT_COUNT` times (default 3, overridable via environment variable). + +```bash +REPEAT_COUNT=5 bash verify_algo.sh +``` + +### `verify_algo_topk.sh` — Naive vs SGLang Comparison + +Similar to `verify_algo.sh` but simpler: runs `naive` TopK and `sglang` TopK back-to-back for `block_sparse_attention`, each with 8 trials. 
+ +### `verify_algo_quant.sh` — INT8 KV-Cache Quantization + +Tests sparse attention with `--kv-cache-dtype int8` to measure accuracy under quantized KV caches. + +```bash +bash verify_algo_quant.sh +``` + +### `verify_sparse_backends.sh` — External Sparse Attention Backends + +Evaluates three external sparse-attention algorithms integrated via the Vortex flow interface: + +- `nsa` (Native Sparse Attention) +- `fsa` (Flash Sparse Attention) +- `flash_moba` (Flash MoBA) + +```bash +bash verify_sparse_backends.sh +``` + +### `verify_algo_topk_mapping.sh` — Full Mapping Mode Sweep + +Comprehensive sweep across all mapping modes: + +1. **Baseline:** `naive` TopK, mode 0 +2. **Calibration:** runs `calibrate_topk.py` to generate `lut.npy` and `quantiles.npy` (skipped if files exist) +3. **Mode 1** (LUT CDF) and **Mode 2** (Quantile) with calibrated tables +4. **Modes 0, 3, 4** (no calibration needed) — Power mode uses `--topk-mapping-power 0.5` +5. **Mode 6** (Asinh) — sweeps `beta` in `[0.5, 1.0, 2.0]` +6. **Mode 7** (Log1p) — sweeps `alpha` in `[0.5, 1.0, 2.0]` + +```bash +export CUDA_VISIBLE_DEVICES=0 +bash verify_algo_topk_mapping.sh +``` + +### `verify_algo_topk_mapping_new.sh` — Parametric Mapping Sweep (Modes 3, 6, 7) + +Focused hyperparameter sweep for the three parametric modes, preceded by an auto-tuning step: + +| Mode | Parameter | Sweep Values | +|------|-----------|-------------| +| 3 (Power) | `p` | 0.1, 0.25, 0.75, 0.9 | +| 6 (Asinh) | `beta` | 0.1, 0.5, 1.0, 2.0, 4.0 | +| 7 (Log1p) | `alpha` | 0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0 | + +Requires `calibration/raw_histograms.npy` for the auto-tune step. + +```bash +export CUDA_VISIBLE_DEVICES=5 +bash verify_algo_topk_mapping_new.sh +``` + +### `verify_algo_topk_mapping_indexcache.sh` — Index Cache (Mode 5) + +Tests the index-cache optimization where even-numbered layers `[2, 4, 6, ..., 26]` reuse top-k indices from the nearest preceding full layer, skipping their indexer entirely. 
+ +```bash +bash verify_algo_topk_mapping_indexcache.sh +``` + +### `run_topk_benchmark.sh` — Unified TopK Benchmark Pipeline + +The most comprehensive benchmarking script. Three-step pipeline: + +1. **Calibrate** — collect real-data histograms + LUT/quantile tables +2. **Kernel bench** — latency + histogram profiling across batch sizes, sequence lengths, and distributions, followed by distribution analysis plots and auto-tuning +3. **E2E accuracy** — full-attention baseline plus every mapping mode + +```bash +bash run_topk_benchmark.sh --gpu 5 --trials 8 --model-name Qwen/Qwen3-1.7B +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model | +| `--topk-val` | 30 | Top-k pages | +| `--trials` | 8 | E2E trial count | +| `--mem` | 0.7 | GPU memory fraction | +| `--gpu` | 5 | CUDA device | +| `--algo` | `block_sparse_attention` | Sparse attention algorithm | +| `--skip-calibrate` | off | Reuse existing calibration | +| `--skip-kernel` | off | Skip kernel-level latency step | +| `--skip-e2e` | off | Skip E2E accuracy step | + +### `run_distribution_analysis.sh` — Bucket Distribution Profiling (All Modes) + +Three-step pipeline to analyze how each mapping mode affects the 256-bin bucket distribution: + +1. **Calibrate** — collect real-data histograms (skippable with `--real-histograms`) +2. **Bench** — histogram profiling with modes 0–8 on `bucket_uniform` and `normal` distributions +3. **Analyze** — generate comparison plots and CSV bucket count tables + +```bash +bash run_distribution_analysis.sh --gpu 5 +bash run_distribution_analysis.sh --gpu 5 --real-histograms /path/to/raw_histograms.npy +``` + +### `run_distribution_analysis_new.sh` — Bucket Distribution Profiling (Modes 3, 6, 7) + +Same pipeline as above but focused on parametric modes only, with an additional auto-tune step: + +1. **Calibrate** (or skip with existing histograms) +2. 
**Auto-tune** — sweep hyperparameters on synthetic data +3. **Bench** — histogram profiling for modes 3, 6, 7, 8 +4. **Analyze** — comparison plots + tables + +```bash +bash run_distribution_analysis_new.sh --gpu 5 +``` + +--- + +## Benchmarks Directory Scripts + +The `benchmarks/` directory contains standalone profiling and analysis tools used by the shell pipelines above. These can also be run independently. + +### `calibrate_topk.py` — Offline Calibration + +Runs the SGLang engine on real prompts from `amc23.jsonl` with histogram collection enabled. Produces three files: + +- `lut.npy` — uint8[256] CDF-equalized LUT for mode 1 +- `quantiles.npy` — float32[256] quantile breakpoints for mode 2 +- `raw_histograms.npy` — raw per-sample 256-bin histograms + +```bash +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration/ +``` + +### `bench_topk.py` — Kernel-Level Latency Benchmark + +Benchmarks `topk_output` (naive/CUB) and `topk_output_sglang` (fast radix) across configurable sweeps of batch size, sequence length, TopK value, KV heads, and score distributions. Optionally collects 256-bin histogram statistics. + +```bash +python benchmarks/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --distributions normal lognormal uniform bucket_uniform \ + --histogram \ + --repeat 100 \ + --output-json results.json +``` + +### `autotune_topk_mapping.py` — Hyperparameter Auto-Tuning + +Sweeps hyperparameters for parametric mapping modes (3, 6, 7) using the `topk_profile_histogram` kernel on synthetic data. Ranks configurations by resolution rate, Gini coefficient, max/mean ratio, and nonzero bins. 
+ +```bash +python benchmarks/autotune_topk_mapping.py \ + --topk-val 30 --batch-size 4 --seq-len 4096 --num-kv-heads 2 \ + --real-histograms calibration/raw_histograms.npy \ + --output-json autotune_results.json +``` + +### `analyze_topk_distribution.py` — Visualization and Analysis + +Loads profiling data and generates: +- Per-segment 256-bin bar charts +- Heatmaps (segments x bins, log-scale) +- Before/after LUT mapping comparisons +- Mode comparison grouped bar charts (Gini + max/mean) +- Distribution comparison plots across data sources +- CSV bucket count tables + +```bash +python benchmarks/analyze_topk_distribution.py \ + --bench-json bench_distribution.json \ + --real-histograms calibration/raw_histograms.npy \ + --output-dir plots/ +``` + +### `profile_topk_distribution.py` — Offline Table Generation + +Computes LUT and quantile tables from pre-collected histograms or raw scores without running a model. Outputs a single `.npz` archive. + +```bash +python benchmarks/profile_topk_distribution.py \ + --histograms-input raw_histograms.npy \ + --output mapping_tables.npz +``` + +### `greedy_layer_search.py` — Index Cache Layer Selection + +Greedy forward-selection of layers whose indexer can be skipped (index cache). Iteratively adds layers to the shared set as long as accuracy stays above `--threshold` times the baseline. 
+ +```bash +cd examples && python ../benchmarks/greedy_layer_search.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 \ + --threshold 0.95 \ + --trials 1 \ + --num-layers 28 \ + --mem 0.7 +``` + +--- + +## Data Files + +| File | Description | +|------|-------------| +| `amc23.jsonl` | AMC 2023 math problems with `prompt` and `answer` fields, used by `verify_algo.py` and `calibrate_topk.py` | + +--- + +## Output Structure + +Results are saved under `results/` in timestamped directories: + +``` +results/ +├── dist_analysis_YYYYMMDD_HHMMSS/ +│ ├── step1_calibrate.log +│ ├── step2_autotune.log / step2_bench.log +│ ├── step3_bench.log / step3_analyze.log +│ ├── step4_analyze.log +│ ├── autotune_results.json +│ ├── bench_distribution.json +│ ├── distribution_comparison_*.png +│ ├── bucket_counts_*.csv +│ └── calibration/ +│ ├── lut.npy +│ ├── quantiles.npy +│ └── raw_histograms.npy +├── topk_benchmark_YYYYMMDD_HHMMSS/ +│ ├── kernel_latency.json +│ ├── e2e/ +│ │ ├── full_attention_baseline.log +│ │ ├── sglang_mode0_none.log +│ │ └── ... +│ └── calibration/ +└── *.log (individual run logs) +``` + +--- + +## Quick Start: Typical Workflow + +```bash +export CUDA_VISIBLE_DEVICES=0 + +# 1. Calibrate to generate LUT + quantile tables +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B --topk-val 30 --mem 0.7 \ + --output-dir examples/calibration/ + +# 2. Run full-attention baseline +python examples/verify_algo.py --full-attention --trials 8 --mem 0.7 + +# 3. Evaluate sparse attention with different mapping modes +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 0 --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 3 --topk-mapping-power 0.25 \ + --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 6 --topk-mapping-power 1.0 \ + --trials 8 --mem 0.7 + +# 4. 
Or run the full pipeline in one shot +bash examples/run_topk_benchmark.sh --gpu 0 --trials 8 +``` diff --git a/examples/remap_function_bench_topk2028.sh b/examples/remap_function_bench_topk2028.sh new file mode 100755 index 0000000..26c529c --- /dev/null +++ b/examples/remap_function_bench_topk2028.sh @@ -0,0 +1,284 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. 
+#
+# Usage:
+# # Default (Qwen/Qwen3-1.7B, block_size=1):
+# bash remap_function_bench_topk2028.sh --gpu 5
+#
+# # Larger model + larger page/block size:
+# bash remap_function_bench_topk2028.sh --gpu 0 \
+# --model-name Qwen/Qwen3-8B \
+# --block-size 32 \
+# --seq-len 16384 --topk-val 512 \
+# --modes "0 3 6 7"
+#
+# # Reuse an existing calibration:
+# bash remap_function_bench_topk2028.sh --gpu 0 \
+# --model-name Qwen/Qwen3-8B \
+# --real-histograms /path/to/calibration/raw_histograms.npy
+# ============================================================
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BENCH_DIR="${SCRIPT_DIR}/../benchmarks"
+
+# ── Defaults ──────────────────────────────────────────────────
+GPU_ID=4
+MODEL_NAME="Qwen/Qwen3-1.7B"
+TOPK_VAL=2048
+MEM=0.7
+# Cap KV / VTX sparse prefill buffer sizing during Step 1 (see calibrate_topk.py --help).
+MAX_TOTAL_TOKENS=64768
+# Min free GiB on the output-dir filesystem before Step 1 (HF weights + cache + logs).
+MIN_FREE_DISK_GB=22
+ALGO="block_sparse_attention"
+SAMPLE_STRIDE=1
+SEQ_LEN=32768
+BLOCK_SIZE=1
+BATCH_SIZE=4
+NUM_KV_HEADS=8
+DISTRIBUTIONS="normal bucket_uniform"
+# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their
+# mapping happens inside compute_stage1_bin, not apply_transform, so
+# split-phase timing isn't meaningful for them.
+MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19"
+# Fallback hparam used only if autotune is explicitly skipped.
+MAPPING_HPARAM=0.5
+REPEAT=100
+WARMUP=20
+# Pinned by default to the cached Qwen3-1.7B histograms below, so Step 1 is skipped;
+# set REAL_HISTOGRAMS="" to recalibrate, or pass --real-histograms to override.
+# REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms.npy"
+#REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy"
+REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy"
+SKIP_AUTOTUNE=0
+# Optional: pre-built autotune JSON to bypass Step 2 entirely.
When set, +# Step 2 is skipped and Step 3 reads its per-mode hparams from this file +# instead. Useful for verification runs where we want to pin the exact +# (mode, hparam) pairs without re-running the latency sweep. +PINNED_AUTOTUNE_JSON="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. 
+if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +# Calibration artifacts live on /var/tmp (large disk), keyed by model. +# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-8B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. 
+if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " Min free disk: ${MIN_FREE_DISK_GB} GiB (Step 1 preflight; 0 = skip)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. 
+if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + # Promote raw_histograms.npy to the shared per-model cache path. + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" +fi + +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. 
+AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then + echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" + AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" + else + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" + fi +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. 
Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/remap_function_bench_topk30.sh b/examples/remap_function_bench_topk30.sh new file mode 100755 index 0000000..3843906 --- /dev/null +++ b/examples/remap_function_bench_topk30.sh @@ -0,0 +1,267 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. 
Remap bench— run `bench_topk.py --remap-bench` with the
+# autotune-selected per-mode hyperparameters.
+#
+# Argument layout mirrors run_distribution_analysis_new.sh.
+#
+# Usage:
+# # Default (Qwen/Qwen3-1.7B, block_size=16):
+# bash remap_function_bench_topk30.sh --gpu 5
+#
+# # Larger model + larger page/block size:
+# bash remap_function_bench_topk30.sh --gpu 0 \
+# --model-name Qwen/Qwen3-8B \
+# --block-size 32 \
+# --seq-len 16384 --topk-val 512 \
+# --modes "0 3 6 7"
+#
+# # Reuse an existing calibration:
+# bash remap_function_bench_topk30.sh --gpu 0 \
+# --model-name Qwen/Qwen3-8B \
+# --real-histograms /path/to/calibration/raw_histograms.npy
+# # Tight GPU: lower calibration KV cap (default 1048576):
+# bash remap_function_bench_topk30.sh --gpu 0 --max-total-tokens 524288
+# ============================================================
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+BENCH_DIR="${SCRIPT_DIR}/../benchmarks"
+
+# ── Defaults ──────────────────────────────────────────────────
+GPU_ID=1
+MODEL_NAME="Qwen/Qwen3-1.7B"
+TOPK_VAL=30
+MEM=0.7
+MAX_TOTAL_TOKENS=1048576
+ALGO="block_sparse_attention"
+SAMPLE_STRIDE=1
+SEQ_LEN=32768
+BLOCK_SIZE=16
+BATCH_SIZE=4
+NUM_KV_HEADS=8
+DISTRIBUTIONS="normal bucket_uniform"
+# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their
+# mapping happens inside compute_stage1_bin, not apply_transform, so
+# split-phase timing isn't meaningful for them.
+MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19 20"
+# Fallback hparam used only if autotune is explicitly skipped.
+MAPPING_HPARAM=0.5
+REPEAT=100
+WARMUP=20
+# NOTE(review): not empty by default — pinned below to the Qwen3-4B histograms
+# while MODEL_NAME defaults to Qwen3-1.7B; confirm, or pass --real-histograms.
+REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" +SKIP_AUTOTUNE=0 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. 
+if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +# Calibration artifacts live on /var/tmp (large disk), keyed by model. +# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. 
+if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. 
+if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + # Promote raw_histograms.npy to the shared per-model cache path. + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" +fi + +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. 
 Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh new file mode 100755 index 0000000..2515015 --- /dev/null +++ b/examples/run_distribution_analysis.sh @@ -0,0 +1,236 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution Profiling Pipeline +# +# Profiles the SGLang TopK kernel's first-pass bucket distribution +# to identify hotspot buckets causing tail latency. +# +# Pipeline steps (this script performs steps 1-3; step 4 is a +# separate analysis pass and is NOT run by this script): +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Auto-tune — sweep hyperparameters to find best per-mode power +# 3. Bench — histogram profiling (bucket_uniform + normal) +# noscale kernels use the same autotuned power +# 4. Analyze — comparison plots + bucket count tables (run separately) +# +# All outputs (JSON, plots, CSV tables, logs) are written to a +# single timestamped folder under examples/results/dist_analysis_*. 
+# +# Usage: +# bash run_distribution_analysis.sh --gpu 5 +# bash run_distribution_analysis.sh --gpu 5 \ +# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# bash run_distribution_analysis.sh --gpu 5 --block-size 16 +# bash run_distribution_analysis.sh --watchdog-timeout 0 # disable calibrate watchdog (fork) +# bash run_distribution_analysis.sh --max-total-tokens 1048576 # cap KV / VTX buffers during calibrate +# Models (default: 1.7B + 4B). Override with repeated --model-name: +# bash run_distribution_analysis.sh --model-name Qwen/Qwen3-1.7B --model-name Qwen/Qwen3-4B +# ============================================================ + +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) +# 13: ExpStretch — y = exp(alpha * x) + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=7 +# Models to run (full pipeline per model). Override with one or more --model-name. 
+MODEL_NAMES=( "Qwen/Qwen3-1.7B" "Qwen/Qwen3-4B" ) +MODEL_NAMES_USER_SET=0 +TOPK_VAL=30 +MEM=0.7 +MAX_TOTAL_TOKENS=1048576 +ALGO="block_sparse_attention" +RADIX_BITS=8 +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +# KV page / block size (passed to benchmarks as --page-size) +BLOCK_SIZE=16 +# The path to the raw_histograms.npy file (set to skip calibration) +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" +HAS_WATCHDOG_TIMEOUT=0 +WATCHDOG_TIMEOUT="" +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) + if [ "${MODEL_NAMES_USER_SET}" -eq 0 ]; then + MODEL_NAMES=() + MODEL_NAMES_USER_SET=1 + fi + MODEL_NAMES+=("$2") + shift 2 + ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --radix-bits) RADIX_BITS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size) BLOCK_SIZE="$2"; shift 2 ;; + --watchdog-timeout) HAS_WATCHDOG_TIMEOUT=1; WATCHDOG_TIMEOUT="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +if [ "${#MODEL_NAMES[@]}" -eq 0 ]; then + echo "ERROR: No models in MODEL_NAMES; pass at least one --model-name." + exit 1 +fi + +# Validate seq_len: need pages/seg > topk_val (reserved=3 pages + slack) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." 
+ echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +echo "============================================================" +echo "Bucket Distribution Profiling Pipeline" +echo " Models (${#MODEL_NAMES[@]}): ${MODEL_NAMES[*]}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Block size: ${BLOCK_SIZE} (--page-size in benchmarks)" +echo " GPU: ${GPU_ID}" +echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + echo " Watchdog (cal): ${WATCHDOG_TIMEOUT}s (0 = off, vortex SGLang fork)" +else + echo " Watchdog (cal): " +fi +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Run id: ${TIMESTAMP}" +echo " Output root: ${RESULTS_DIR}/dist_analysis__${TIMESTAMP}/" +echo "============================================================" + +for MODEL_NAME in "${MODEL_NAMES[@]}"; do + MODEL_SLUG="${MODEL_NAME//\//_}" + RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_${TIMESTAMP}" + mkdir -p "${RUN_DIR}" + + echo "" + echo "############################ MODEL: ${MODEL_NAME} ############################" + echo " Output: ${RUN_DIR}" + + # ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── + if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" + else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + CALIB_EXTRA_ARGS=() + if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + CALIB_EXTRA_ARGS+=(--watchdog-timeout "${WATCHDOG_TIMEOUT}") + fi + 
python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + "${CALIB_EXTRA_ARGS[@]}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + fi + + # Pick up lut.npy / quantiles.npy if calibration produced them. + CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" + LUT_PATH="" + Q_PATH="" + [ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" + [ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + [ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" + [ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + + # ── Step 2: Auto-tune — rank by fused-topk kernel latency ────── + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" + + AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + AUTOTUNE_EXTRA=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len ${SEQ_LEN} \ + --page-size "${BLOCK_SIZE}" \ + --num-kv-heads 2 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + + echo ">>> Step 2: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" + + # ── Step 3: Remap benchmark with autotuned hparams ────────────── + echo "" + echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" + + BENCH_JSON="${RUN_DIR}/remap_bench.json" + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes 4 \ + --num-kv-heads 8 \ + --seq-lens ${SEQ_LEN} \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions bucket_uniform normal \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ + --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" + + echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}" +done + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution Profiling Complete" +echo " Per-model outputs under ${RESULTS_DIR}/ (run id ${TIMESTAMP}):" +echo " dist_analysis__${TIMESTAMP}/" +echo " autotune_results.json, bench_distribution.json, plots, CSV, logs" +echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh new file mode 100755 index 0000000..38438bd --- /dev/null +++ b/examples/run_distribution_analysis_new.sh @@ -0,0 +1,196 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution / Remap Latency Pipeline (parametric modes) +# +# Tests the surviving parametric mapping modes after the lean +# refactor: +# Mode 3 (Power): y = sign(x) * |x|^p +# Mode 6 (Asinh): y = asinh(beta * x) +# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) +# Mode 9 (Erf): y = 
erf(alpha * x) +# Mode 10 (Tanh): y = tanh(alpha * x) +# Mode 13 (ExpStretch): y = exp(alpha * x) +# +# Pipeline: +# 1. Calibrate — collect real-distribution histograms from the +# chosen model (skippable via --real-histograms). +# 2. Autotune — rank per-mode hparams by measured fused-topk +# kernel latency (lowest wins). +# 3. Remap bench— bench_topk.py --remap-bench fed with the +# autotune JSON. Reports per-mode remap / topk / +# fused / baseline latencies and threshold stats. +# +# Usage: +# bash run_distribution_analysis_new.sh --gpu 5 +# bash run_distribution_analysis_new.sh --gpu 5 \ +# --model-name Qwen/Qwen3-8B --block-size 32 +# bash run_distribution_analysis_new.sh --gpu 5 \ +# --real-histograms /path/to/raw_histograms.npy +# bash run_distribution_analysis_new.sh --gpu 5 --max-total-tokens 524288 +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +MAX_TOTAL_TOKENS=1048576 +ALGO="block_sparse_attention" +SEQ_LEN=65536 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="bucket_uniform normal" +# LUT_CDF (1) / QUANTILE (2) are evaluated only when calibration produces +# lut.npy / quantiles.npy. 0 baseline is always included by --remap-bench. 
+MAPPING_MODES="1 2 3 6 7 8 9 10 11 13" +REPEAT=100 +WARMUP=20 +REAL_HISTOGRAMS="" + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." 
+ echo " Minimum: ${MIN_SEQ_LEN}" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Bucket Distribution / Remap Latency Pipeline (parametric modes)" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate ─────────────────────────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. 
+CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ── Step 2: Autotune (latency-ranked) ─────────────────────────── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +AUTOTUNE_EXTRA=() +[ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" +echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + +# ── Step 3: Remap bench with autotuned hparams ────────────────── +echo "" +echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" +BENCH_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" +echo ">>> Step 3: Done. 
Remap bench saved to ${BENCH_JSON}" + +# ── Summary ───────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution / Remap Latency Pipeline Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution" +echo " autotune_results.json — latency-ranked hparams" +echo " remap_bench.json — remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh new file mode 100755 index 0000000..f3eabff --- /dev/null +++ b/examples/run_topk_benchmark.sh @@ -0,0 +1,253 @@ +#!/usr/bin/env bash +# ============================================================ +# Unified TopK Benchmark +# +# Three-step pipeline on a single configurable model: +# Step 1: Calibrate — run the model to collect +# real-distribution histograms +# (raw_histograms.npy, lut.npy, +# quantiles.npy). +# Step 2: Latency autotune + bench — rank per-mode hparams by +# measured fused-topk kernel +# latency, then run the +# remap / topk / fused / baseline +# comparison. +# Step 3: E2E accuracy — verify_algo.py on the same +# model for the unmapped baseline +# plus each mapping mode, with +# autotuned hparams. 
+# +# Usage: +# bash run_topk_benchmark.sh --gpu 0 +# bash run_topk_benchmark.sh --gpu 0 --model-name Qwen/Qwen3-8B \ +# --block-size 32 --topk-val 512 +# bash run_topk_benchmark.sh --gpu 0 --max-total-tokens 1048576 +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +TRIALS=8 +MEM=0.7 +MAX_TOTAL_TOKENS=1048576 +ALGO="block_sparse_attention" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +SEQ_LEN=32768 +BENCHMARKS="amc23" +SKIP_CALIBRATE=false +SKIP_KERNEL=false +SKIP_E2E=true + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --trials) TRIALS="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --skip-calibrate) SKIP_CALIBRATE=true; shift ;; + --skip-kernel) SKIP_KERNEL=true; shift ;; + --skip-e2e) SKIP_E2E=false; shift ;; # --skip-e2e actually toggles it OFF (enables) + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${MODEL_SLUG}_${BENCH_LABEL}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo 
"============================================================" +echo "Unified TopK Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN}" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Trials: ${TRIALS}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " GPU: ${GPU_ID}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate ──────────────────────────────────────── +CALIBRATION_DIR="${RUN_DIR}/calibration" +if [ "${SKIP_CALIBRATE}" = true ] && [ -d "${CALIBRATION_DIR}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (--skip-calibrate)" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — real topk histograms + LUT/quantiles" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + echo ">>> Step 1: Done." 
+fi + +REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIBRATION_DIR}/lut.npy" ] && LUT_PATH="${CALIBRATION_DIR}/lut.npy" +[ -f "${CALIBRATION_DIR}/quantiles.npy" ] && Q_PATH="${CALIBRATION_DIR}/quantiles.npy" +[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" +[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + +# ── Step 2: Latency autotune + remap bench ─────────────────── +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_KERNEL}" = true ]; then + echo "" + echo ">>> Step 2: SKIPPED (--skip-kernel)" +else + echo "" + echo ">>> Step 2a: Auto-tuning per-mode hparams by fused-topk kernel latency" + AUTOTUNE_EXTRA=() + [ -f "${REAL_HIST_PATH}" ] && AUTOTUNE_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --page-size "${BLOCK_SIZE}" \ + --warmup 20 --repeat 100 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2a_autotune.log" + echo ">>> Step 2a: Done. Autotune saved to ${AUTOTUNE_JSON}" + + echo "" + echo ">>> Step 2b: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" + BENCH_JSON="${RUN_DIR}/kernel_latency.json" + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions normal bucket_uniform \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ + --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --warmup 20 --repeat 100 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2b_kernel_bench.log" + echo ">>> Step 2b: Done. Results saved to ${BENCH_JSON}" +fi + +# ── Step 3: E2E accuracy ───────────────────────────────────── +if [ "${SKIP_E2E}" = true ]; then + echo "" + echo ">>> Step 3: SKIPPED (default). Pass --skip-e2e to toggle it ON." +else + echo "" + echo ">>> Step 3: E2E accuracy comparison" + + # Extract autotuned hparams per mode. + eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode'); lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}') +" "${AUTOTUNE_JSON}")" + + E2E_DIR="${RUN_DIR}/e2e" + mkdir -p "${E2E_DIR}" + + run_e2e() { + # $1=label, remaining args passed to verify_algo.py + local label="$1"; shift + local logfile="${E2E_DIR}/${label}.log" + echo "" + echo " --- ${label} ---" + { time python "${SCRIPT_DIR}/verify_algo.py" \ + --trials "${TRIALS}" \ + --topk-val "${TOPK_VAL}" \ + --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ + --mem "${MEM}" \ + "$@" ; } \ + 2>&1 | tee "${logfile}" + } + + run_mapped() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + local extra=(--vortex-module-name "${ALGO}") + if [ "${mode}" -eq 0 ]; then + extra+=(--topk-type sglang) + else + extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + run_e2e 
"${label}" "${extra[@]}" + } + + run_e2e "full_attention_baseline" --full-attention + run_e2e "naive_topk" --vortex-module-name "${ALGO}" --topk-type naive + run_mapped 0 0.5 "sglang_m0_none" + run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_power_p${BEST_HPARAM_3}" + run_mapped 4 0.5 "sglang_m4_log" + run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_asinh_beta${BEST_HPARAM_6}" + run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_log1p_alpha${BEST_HPARAM_7}" + run_mapped 8 0.5 "sglang_m8_trunc8" + run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_erf_alpha${BEST_HPARAM_9}" + run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_tanh_alpha${BEST_HPARAM_10}" + run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_subtract_pivot${BEST_HPARAM_11}" + run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_expstretch_alpha${BEST_HPARAM_13}" + + echo "" + echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" +fi + +# ── Final Summary ───────────────────────────────────────────── +echo "" +echo "============================================================" +echo "TopK Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All results: ${RUN_DIR}" +echo " Calibration: ${CALIBRATION_DIR}" +[ "${SKIP_KERNEL}" != true ] && echo " Autotune: ${AUTOTUNE_JSON}" +[ "${SKIP_KERNEL}" != true ] && echo " Kernel JSON: ${RUN_DIR}/kernel_latency.json" +[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" +echo "============================================================" diff --git a/examples/test_topk.py b/examples/test_topk.py new file mode 100644 index 0000000..01edc7b --- /dev/null +++ b/examples/test_topk.py @@ -0,0 +1,118 @@ +import torch +import triton +# topk_output_sglang expects sparse_kv_indptr before dense_kv_indices (unlike topk_output). 
+from vortex_torch_C import topk_output_sglang as topk_output + +SEQ_LENS = [4096] +BATCH_SIZES = [256] + +K = 32 +RESERVE_BOS = 0 +RESERVE_EOS = 0 +DEVICE = "cuda" + + +def make_inputs(batch_size, seq_len, k, reserve_bos, reserve_eos, device="cuda"): + dense_kv_indptr = torch.arange( + 0, batch_size * seq_len + 1, seq_len, dtype=torch.int32, device=device + ) + + dense_kv_indices = torch.arange( + 0, batch_size * seq_len, dtype=torch.int32, device=device + ) + + scores = torch.randn( + batch_size * seq_len, dtype=torch.bfloat16, device=device + ) + + # ✅ Fixed CSR-style sparse indptr + sparse_kv_indptr = torch.arange( + 0, batch_size * k + 1, k, dtype=torch.int32, device=device + ) + + sparse_kv_indices = torch.empty( + batch_size * k, dtype=torch.int32, device=device + ) + + return ( + scores, + dense_kv_indptr, + dense_kv_indices, + sparse_kv_indptr, + sparse_kv_indices, + ) + + +def bench_one(batch_size, seq_len, k, reserve_bos, reserve_eos): + ( + scores, + dense_kv_indptr, + dense_kv_indices, + sparse_kv_indptr, + sparse_kv_indices, + ) = make_inputs( + batch_size=batch_size, + seq_len=seq_len, + k=k, + reserve_bos=reserve_bos, + reserve_eos=reserve_eos, + device=DEVICE, + ) + + def fn(): + topk_output( + scores, + dense_kv_indptr, + sparse_kv_indptr, + dense_kv_indices, + sparse_kv_indices, + batch_size, + k, + reserve_bos, + reserve_eos, + seq_len, + ) + + # warmup + for _ in range(10): + fn() + torch.cuda.synchronize() + + ms = triton.testing.do_bench( + fn, + warmup=100, + rep=1000, + return_mode="mean", + ) + return ms + + +def main(): + torch.cuda.init() + + results = {} + + for bs in BATCH_SIZES: + results[bs] = {} + for seq_len in SEQ_LENS: + ms = bench_one( + batch_size=bs, + seq_len=seq_len, + k=K, + reserve_bos=RESERVE_BOS, + reserve_eos=RESERVE_EOS, + ) + results[bs][seq_len] = ms + print(f"bs={bs:>3}, seq_len={seq_len:>4} -> {ms:.6f} ms") + + print("\nLatency table (ms):") + header = "bs\\seq".ljust(10) + "".join(f"{s:>12}" for s in SEQ_LENS) + 
print(header) + + for bs in BATCH_SIZES: + row = f"{bs:<10}" + "".join(f"{results[bs][s]:>12.4f}" for s in SEQ_LENS) + print(row) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/verify_aim24.py b/examples/verify_aim24.py new file mode 100644 index 0000000..9e54a96 --- /dev/null +++ b/examples/verify_aim24.py @@ -0,0 +1,106 @@ +import json +import sys +sys.path.append("../") +import python.sglang as sgl +from transformers import AutoTokenizer +import os +from tqdm import tqdm +import time +import torch +os.environ["TOKENIZERS_PARALLELISM"] = "false" +MATH_QUERY_TEMPLATE = """ +Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. + +{Question} +""".strip() + +from datasets import load_dataset, Dataset, concatenate_datasets +def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial: int = 1, rank: int = 0, world_size: int = 1): + requests = [] + + # Step 1: Expand dataset trial times + if trial > 1: + dataset = Dataset.from_dict(dataset.to_dict().copy())  # ensure copy + datasets = [dataset] * trial + dataset = concatenate_datasets(datasets) + + total = len(dataset) + + # Step 2: Partition across ranks + per_proc = total // world_size + remainder = total % world_size + start = rank * per_proc + min(rank, remainder) + end = start + per_proc + (1 if rank < remainder else 0) + subset = dataset.select(list(range(start, end))) + + # Step 3: Format requests + for data in dataset: + conversations = [ + {"role": "user", "content": data_format.format(Question=data[field_name])} + ] + data["conversations"] = conversations + requests.append(data) + + return requests + + +def main(): + model_name = "Qwen/Qwen3-0.6B" + llm = 
sgl.Engine(model_path=model_name, + disable_cuda_graph=False, + page_size=16, + vortex_num_selected_pages=29, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + mem_fraction_static=0.9, + vortex_cg=True, + vortex_graph=True, + vortex_module_name="block_sparse_attention", + vortex_max_seq_lens=20480 + ) + + dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + + requests = generate_requests(dataset, "problem", MATH_QUERY_TEMPLATE) + + + + texts = [ + x["conversations"] for x in requests + ] + + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompts = [ + tokenizer.apply_chat_template( + text, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) for text in texts + ] * 8 + + sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 16384} + total_tokens = 0 + total_time = 0.0 + start = time.perf_counter() + o = llm.generate(prompts, sampling_params) + elapsed = time.perf_counter() - start + total_time += elapsed + e2e_time = 0 + with open(f"0.6B_VTX_CG_TP1_16K.jsonl", "w", encoding="utf-8") as f: + for item in o: + total_tokens += item["meta_info"]["completion_tokens"] + e2e_time = max(e2e_time, item["meta_info"]["e2e_latency"]) + json.dump(item, f, ensure_ascii=False) + f.write("\n") + + meta_data = {"e2e_time": e2e_time, "total_time": total_time, "total_tokens": total_tokens, "throughput": total_tokens / total_time} + json.dump(meta_data, f, ensure_ascii=False) + f.write("\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e290a81..a78f1e6 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -11,7 +11,11 @@ from lighteval.models.model_output import ModelResponse from datasets import load_dataset, Dataset, concatenate_datasets import argparse 
+import ast import json +import os +import subprocess +import sys MATH_QUERY_TEMPLATE = """ Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. @@ -47,6 +51,63 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial return requests +BENCHMARK_REGISTRY = { + "amc23": { + "type": "jsonl", + "path": "amc23.jsonl", + "prompt_key": "prompt", + "answer_key": "answer", + "question_key": "question", + }, + "aime24": { + "type": "huggingface", + "path": "HuggingFaceH4/aime_2024", + "split": "train", + "field_name": "problem", + "answer_key": "answer", + }, +} + +def _load_benchmark(benchmark_name: str, trials: int, tokenizer=None): + """Load benchmark data and return (prompts, requests) tuple.""" + cfg = BENCHMARK_REGISTRY[benchmark_name] + + if cfg["type"] == "jsonl": + script_dir = os.path.dirname(os.path.abspath(__file__)) + jsonl_path = os.path.join(script_dir, cfg["path"]) + with open(jsonl_path, "r", encoding="utf-8") as f: + requests = [json.loads(line) for line in f] + requests = requests * trials + prompts = [req[cfg["prompt_key"]] for req in requests] + return prompts, requests + + elif cfg["type"] == "huggingface": + dataset = load_dataset(cfg["path"], split=cfg["split"]) + hf_requests = generate_requests(dataset, cfg["field_name"], MATH_QUERY_TEMPLATE) + # Normalize keys: ensure "question" and "answer" exist + for req in hf_requests: + if "question" not in req and cfg["field_name"] in req: + req["question"] = req[cfg["field_name"]] + # Build chat-template prompts if tokenizer is provided + if tokenizer is not None: + texts = [x["conversations"] for x in hf_requests] + prompts = [ + tokenizer.apply_chat_template( + text, tokenize=False, 
add_generation_prompt=True, enable_thinking=True + ) for text in texts + ] * trials + hf_requests = hf_requests * trials + else: + prompts = [ + MATH_QUERY_TEMPLATE.format(Question=x[cfg["field_name"]]) for x in hf_requests + ] * trials + hf_requests = hf_requests * trials + return prompts, hf_requests + + else: + raise ValueError(f"Unknown benchmark type: {cfg['type']}") + + def verify_algos( trials: int = 2, topk_val: int = 30, @@ -54,13 +115,19 @@ def verify_algos( vortex_module_name: str = "gqa_block_sparse_attention", model_name: str = "Qwen/Qwen3-1.7B", sparse_attention: bool = True, -mem: float = 0.8 -): +mem: float = 0.8, +kv_cache_dtype: str = "auto", +topk_type: str = "naive", +topk_mapping_mode: int = 0, +topk_mapping_hparam: float = 0.5, +disable_cuda_graph: bool = False, +benchmark: str = "amc23", +): - llm = sgl.Engine(model_path=model_name, - disable_cuda_graph=False, + llm = sgl.Engine(model_path=model_name, + disable_cuda_graph=disable_cuda_graph, page_size=page_size, - vortex_topk_val=topk_val, + vortex_topk_val=topk_val, disable_overlap_schedule=True, attention_backend="flashinfer", enable_vortex_sparsity=sparse_attention, @@ -69,17 +136,17 @@ def verify_algos( vortex_layers_skip=list(range(1)), vortex_module_name=vortex_module_name, vortex_max_seq_lens=12288, - mem_fraction_static=mem + mem_fraction_static=mem, + kv_cache_dtype=kv_cache_dtype, + vortex_topk_type=topk_type, + vortex_topk_mapping_mode=topk_mapping_mode, + vortex_topk_mapping_hparam=topk_mapping_hparam, ) - - with open("examples/amc23.jsonl", "r", encoding="utf-8") as f: - requests = [json.loads(line) for line in f] - - requests = requests * trials - prompts = [req["prompt"] for req in requests] + tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None + prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 8192} - + o = llm.generate(prompts, 
sampling_params) gold_metric = MultilingualExtractiveMatchMetric( language=Language.ENGLISH, @@ -89,7 +156,7 @@ def verify_algos( pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), aggregation_function=max, ) - + results = [] for data, item in zip(requests, o): golds = [data["answer"]] @@ -99,7 +166,7 @@ def verify_algos( result = gold_metric.compute(model_response=ModelResponse(text=[predictions]), doc=target) except: result = 0.0 - + results.append( { "score": float(result), @@ -110,7 +177,15 @@ def verify_algos( "num_tokens": item["meta_info"]["completion_tokens"] } ) - + # --- Per-question debug output --- + # print(f"[Q{len(results):03d}] score={float(result):.1f} " + # f"tokens={item['meta_info']['completion_tokens']} " + # f"latency={item['meta_info']['e2e_latency']:.2f}s " + # f"gold={golds[0]}") + # print(f" question: {data['question'][:120]}...") + # print(f" prediction: {predictions[:200]}...") + # print() + total_accuracy = 0.0 total_tokens = 0 @@ -130,12 +205,17 @@ def verify_algos( if sparse_attention: llm_cfg = AutoConfig.from_pretrained(model_name) - flow = vortex_torch.flow.build_vflow(vortex_module_name) - memory_access_runtime = flow.run_indexer_virtual( - group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, - page_size=page_size, - head_dim=llm_cfg.head_dim, - ) + flow = vortex_torch.flow.build_vflow(vortex_module_name) + try: + memory_access_runtime = flow.run_indexer_virtual( + group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, + page_size=page_size, + head_dim=llm_cfg.head_dim, + ) + except Exception: + # External algorithms (nsa, fsa, flash_moba) override run_indexer_virtual + # to return 0 since their vendored kernels don't participate in vortex profiling + memory_access_runtime = 0.0 else: memory_access_runtime = 0.0 @@ -203,20 +283,76 @@ def parse_args(): default=0.8, help="memory fraction in sglang", ) + + parser.add_argument( + "--kv-cache-dtype", + 
type=str,
+ default="auto",
+ choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"],
+ help='KV cache dtype (default: "auto").',
+ )
+
+ parser.add_argument(
+ "--topk-type",
+ type=str,
+ default="naive",
+ choices=["naive", "sglang", "sglang_fused"],
+ help='TopK kernel type: "naive" (CUB radix), "sglang" (unmapped baseline), "sglang_fused" (fused remap + topk). Default: "naive".',
+ )
+ parser.add_argument(
+ "--topk-mapping-mode",
+ type=int,
+ default=0,
+ choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13],
+ help='TopK mapping mode for sglang_fused: 0=none, 1=lut_cdf (calibrated), '
+ '2=quantile (calibrated), 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, '
+ '9=erf, 10=tanh, 11=subtract, 13=exp_stretch (default: 0).',
+ )
+
+ parser.add_argument(
+ "--topk-mapping-hparam", "--topk-mapping-power",
+ type=float,
+ default=0.5,
+ dest="topk_mapping_hparam",
+ help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6), alpha (mode 7/9/10/13), pivot (mode 11). Default: 0.5.',
+ )
+
+ parser.add_argument(
+ "--benchmark",
+ type=str,
+ nargs="+",
+ default=["amc23"],
+ help="Benchmark(s) to run. Available: amc23, aime24. "
+ "Use multiple values to run several benchmarks sequentially (default: amc23).",
+ )
+
 return parser.parse_args()


 if __name__ == "__main__":
 args = parse_args()
- summary = verify_algos(
- trials=args.trials,
- topk_val=args.topk_val,
- page_size=args.page_size,
- vortex_module_name=args.vortex_module_name,
- model_name=args.model_name,
- sparse_attention=not(args.full_attention),
- mem=args.mem
- )
- print(summary)
+ for bench_name in args.benchmark:
+ if bench_name not in BENCHMARK_REGISTRY:
+ print(f"WARNING: Unknown benchmark '{bench_name}', skipping. 
Available: {list(BENCHMARK_REGISTRY.keys())}") + continue + print(f"\n{'='*60}") + print(f"Benchmark: {bench_name}") + print(f"{'='*60}") + summary = verify_algos( + trials=args.trials, + topk_val=args.topk_val, + page_size=args.page_size, + vortex_module_name=args.vortex_module_name, + model_name=args.model_name, + sparse_attention=not(args.full_attention), + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, + topk_mapping_mode=args.topk_mapping_mode, + topk_mapping_hparam=args.topk_mapping_hparam, + benchmark=bench_name, + ) + summary["benchmark"] = bench_name + print(summary) exit(0) \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 17c2a5e..7a96d1e 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,17 +1,26 @@ #!/usr/bin/env bash set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( - + "block_sparse_attention" ) -for algo in "${sparse_algos[@]}"; do - echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" - python examples/verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --mem 0.7 -done +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh new file mode 100644 index 0000000..a2663e9 --- /dev/null +++ b/examples/verify_algo_quant.sh @@ -0,0 +1,27 @@ 
+#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_topk.sh b/examples/verify_algo_topk.sh new file mode 100644 index 0000000..6b2744a --- /dev/null +++ b/examples/verify_algo_topk.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +REPEAT_COUNT="${REPEAT_COUNT:-3}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_naive_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type naive" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_sglang_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type sglang" + echo ">>> Saving 
results to ${OUTFILE}"
+ { time python verify_algo.py \
+ --trials 8 \
+ --topk-val 30 \
+ --vortex-module-name "${algo}" \
+ --model-name Qwen/Qwen3-1.7B \
+ --topk-type sglang \
+ --mem 0.7 ; } \
+ 2>&1 | tee "${OUTFILE}"
+ done
+done
\ No newline at end of file
diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh
new file mode 100644
index 0000000..f361e59
--- /dev/null
+++ b/examples/verify_algo_topk_mapping.sh
@@ -0,0 +1,204 @@
+#!/usr/bin/env bash
+# ============================================================
+# E2E accuracy comparison: naive baseline + unmapped sglang +
+# every surviving parametric mapping mode (3, 4, 6, 7, 8, 9, 10, 11, 13)
+# with per-mode hyperparameters picked by autotune_topk_mapping.py
+# (ranked by measured fused-topk kernel latency, lowest wins).
+#
+# Surviving mapping modes after the lean refactor:
+# 0: None — unmapped baseline
+# 3: Power — y = sign(x) * |x|^p
+# 4: Log — y = sign(x) * log(|x| + 1)
+# 6: Asinh — y = asinh(beta * x)
+# 7: Log1p — y = sign(x) * log1p(alpha * |x|)
+# 8: Trunc8 — mantissa truncation (no knob)
+# 9: Erf — y = erf(alpha * x)
+# 10: Tanh — y = tanh(alpha * x)
+# 11: Subtract — pivot-shift remap
+# 13: ExpStretch — y = exp(alpha * x)
+# ============================================================
+set -e

+# ── Defaults ────────────────────────────────────────────────── +GPU_ID=0 +BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; 
+ --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +sparse_algos=( "block_sparse_attention" ) + +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/${MODEL_SLUG}_${BENCH_LABEL}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# ============================================================ +# Baseline: naive topk +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_naive_${TIMESTAMP}.log" + echo ">>> naive topk algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --topk-type naive \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Calibrate (optional) — real-distribution histograms +# ============================================================ +if [ -z "${REAL_HISTOGRAMS}" ]; then + CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + echo ">>> Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" + for algo in "${sparse_algos[@]}"; do + echo ">>> Calibrating ${MODEL_NAME} for ${algo}..." + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${algo}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" + done + REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. 
+CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ============================================================ +# Auto-tune — rank by fused-topk kernel latency +# ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + if [ -f "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Auto-tuning hyperparameters (real distribution, latency-ranked)" + echo "============================================================" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" + else + echo ">>> WARNING: ${REAL_HISTOGRAMS} not found — autotune will use synthetic data" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + fi +fi + +# Extract best per-mode hparam (ranked by kernel latency, lowest wins). 
+eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode'); lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}') +" "${AUTOTUNE_JSON}")" +echo ">>> Autotuned hparams (lowest fused-topk latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" +echo "" + +run_mapped() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra=() + if [ "${mode}" -eq 0 ]; then + extra+=(--topk-type sglang) + else + extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + { time python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 \ + "${extra[@]}" ; } \ + 2>&1 | tee "${out}" + done +} + +run_mapped 0 0.5 "sglang_m0" +run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" +run_mapped 4 0.5 "sglang_m4" +run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" +run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" +run_mapped 8 0.5 "sglang_m8" +run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" +run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" +run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" +run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" + +echo "" +echo "============================================================" +echo "All runs complete. 
Results in ${RESULTS_DIR}/" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo "============================================================" diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh new file mode 100644 index 0000000..2cdc526 --- /dev/null +++ b/examples/verify_algo_topk_mapping_new.sh @@ -0,0 +1,221 @@ +#!/usr/bin/env bash +# ============================================================ +# E2E accuracy sweep over the surviving parametric mapping modes. +# Each mode runs verify_algo.py with the per-mode hyperparameter +# that autotune_topk_mapping.py picked as having the lowest +# measured fused-topk-kernel latency. +# +# Mapping modes (after the lean refactor): +# 0: None — unmapped baseline (no remap) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) [no knob] +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ +set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +TOPK_VAL=30 +BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) 
MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +sparse_algos=( "block_sparse_attention" ) + +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/topk_mapping_${MODEL_SLUG}_topk${TOPK_VAL}_${BENCH_LABEL}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# ============================================================ +# Step 0: Calibrate (optional) — real-distribution histograms +# ============================================================ +if [ -z "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Step 0: Calibrating ${MODEL_NAME} for real-distribution histograms" + echo " Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" + echo "============================================================" + CAL_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + mkdir -p "${CAL_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${sparse_algos[0]}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CAL_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibrate_${TIMESTAMP}.log" + REAL_HISTOGRAMS="${CAL_DIR}/raw_histograms.npy" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. 
+CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ============================================================ +# Step 1: Auto-tune — rank by profiled fused-topk kernel latency +# ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + echo "============================================================" + echo "Step 1: Auto-tuning hyperparameters by fused-topk kernel latency" + echo "============================================================" + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val ${TOPK_VAL} \ + --batch-size ${BATCH_SIZE} \ + --seq-len ${SEQ_LEN} \ + --num-kv-heads ${NUM_KV_HEADS} \ + --page-size ${BLOCK_SIZE} \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" +fi + +# Extract best per-mode hparam (ranked by measured kernel latency, lowest wins) +eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode') + lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + v = best.get(m, {}).get('param', 0.5) + print(f'BEST_HPARAM_{m}={v}') +" "${AUTOTUNE_JSON}")" +echo ">>> Autotuned hparams (lowest topk kernel latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} 
mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" +echo "" + +run_verify() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra_args=() + if [ "${mode}" -eq 0 ]; then + extra_args+=(--topk-type sglang) + else + extra_args+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + { time python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 \ + "${extra_args[@]}" ; } \ + 2>&1 | tee "${out}" + done +} + +echo "============================================================" +echo "Baseline: sglang (no remap)" +echo "============================================================" +run_verify 0 0.5 "sglang_m0" + +echo "============================================================" +echo "Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" +echo "============================================================" +run_verify 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" + +echo "============================================================" +echo "Mode 4 (log)" +echo "============================================================" +run_verify 4 0.5 "sglang_m4" + +echo "============================================================" +echo "Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" +echo "============================================================" +run_verify 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" + +echo "============================================================" +echo "Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" +echo "============================================================" +run_verify 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" + +echo 
"============================================================" +echo "Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" +echo "============================================================" +run_verify 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" + +echo "============================================================" +echo "Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" +echo "============================================================" +run_verify 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" + +echo "============================================================" +echo "Mode 8 (trunc8)" +echo "============================================================" +run_verify 8 0.5 "sglang_m8" + +echo "============================================================" +echo "Mode 11 (subtract) — pivot=${BEST_HPARAM_11} (autotuned)" +echo "============================================================" +run_verify 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" + +echo "============================================================" +echo "Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" +echo "============================================================" +run_verify 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" + +echo "" +echo "============================================================" +echo "All runs complete. 
Results in ${RESULTS_DIR}/" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo " Mode 3 (power): p = ${BEST_HPARAM_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_HPARAM_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_HPARAM_7} (autotuned)" +echo " Mode 9 (erf): alpha = ${BEST_HPARAM_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_HPARAM_10} (autotuned)" +echo " Mode 13 (exp_stretch):alpha = ${BEST_HPARAM_13} (autotuned)" +echo "============================================================" diff --git a/examples/verify_external_backends.sh b/examples/verify_external_backends.sh new file mode 100755 index 0000000..12600d0 --- /dev/null +++ b/examples/verify_external_backends.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=6 + +sparse_algos=( + "nsa" + "fsa" + "flash_moba" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done diff --git a/setup.py b/setup.py index e272326..c973181 100644 --- a/setup.py +++ b/setup.py @@ -16,15 +16,22 @@ sources=[ 'csrc/register.cc', 'csrc/utils_sglang.cu', - 'csrc/topk.cu' + 'csrc/topk.cu', + 'csrc/topk_sglang.cu', + 'csrc/topk_sglang_profile.cu', + 'csrc/topk_sglang_ori.cu', ], include_dirs=['csrc'], extra_compile_args={ 'cxx': ['-O3'], 'nvcc': [ '-O3', + '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', - '-gencode=arch=compute_90,code=sm_90' + '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + 
'-gencode=arch=compute_120,code=sm_120' + ], }, ), diff --git a/third_party/sglang b/third_party/sglang index e383c0f..b7825d0 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit e383c0fdd551f74f24d247e8a7cc8013861949ad +Subproject commit b7825d08399fccdf1f29a5380d6601fcef59aca1 diff --git a/vortex_torch/attention_backend/__init__.py b/vortex_torch/attention_backend/__init__.py new file mode 100644 index 0000000..9ca7855 --- /dev/null +++ b/vortex_torch/attention_backend/__init__.py @@ -0,0 +1,3 @@ +# Vendored sparse attention backends for Vortex forward_extend. +# NSA and FSA are pure Triton kernels. +# FlashMoBA requires flash_moba_cuda C++ extension (pip install flash_moba). diff --git a/vortex_torch/attention_backend/flashmoba/__init__.py b/vortex_torch/attention_backend/flashmoba/__init__.py new file mode 100644 index 0000000..aa912b9 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/__init__.py @@ -0,0 +1,13 @@ +from .flash_moba_interface import ( + flash_moba_varlen_func, + flash_moba_attn_varlen_func, + flash_topk_varlen_func, + decide_lg_block_m, +) + +__all__ = [ + "flash_moba_varlen_func", + "flash_moba_attn_varlen_func", + "flash_topk_varlen_func", + "decide_lg_block_m", +] diff --git a/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py new file mode 100644 index 0000000..c196c21 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py @@ -0,0 +1,730 @@ +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import os + +try: + import flash_moba_cuda as flash_moba_gpu +except ImportError: + flash_moba_gpu = None +from .triton_mean_pool import flash_topk_mean_pool + +########################################################################################################################## +# Helper functions 
+########################################################################################################################## + +def round_multiple(x: int, m: int) -> int: + """Round x up to the nearest multiple of m.""" + return ((x + m - 1) // m) * m + +########################################################################################################################## + +def decide_lg_block_m(top_k: int, chunk_size: int, seqlen: int, causal: bool = False) -> int: + sparsity = 0.0 + budget = top_k * chunk_size + if causal: + density = (2*(budget * seqlen) - budget**2) / (seqlen**2) + else: + density = budget / seqlen + + sparsity = 1 - density + + if sparsity <= 0.5: + lg_block_m = 128 + elif sparsity <= 0.7: + lg_block_m = 256 + elif sparsity <= 0.8: + lg_block_m = 512 + elif sparsity <= 0.9: + lg_block_m = 768 + else: + lg_block_m = 1024 + + # [Optimization] Hardware-aware cap for A6000/3090/4090 to avoid Shared Memory OOM + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + # sm86 (A6000, 3090) and sm89 (4090, L40) have smaller shared memory than A100 (sm80) + if major == 8 and minor > 0: + lg_block_m = min(lg_block_m, 512) + + return lg_block_m + +########################################################################################################################## + +# torch.compile() support is only enabled for pytorch >= 2.4 +# The reason for this is that we are using the new custom_op and register_fake +# APIs, which support inplace modification of inputs in the function itself +if torch.__version__ >= "2.4.0": + _torch_custom_op_wrapper = torch.library.custom_op + _torch_register_fake_wrapper = torch.library.register_fake +else: + def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + if fn is None: + return wrap + return fn + def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + if fn is 
None: + return wrap + return fn + _torch_custom_op_wrapper = noop_custom_op_wrapper + _torch_register_fake_wrapper = noop_register_fake_wrapper + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +########################################################################################################################## +# Custom ops +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_moba_fused_topk", mutates_args=(), device_types="cuda") +def _moba_fused_topk( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + + col_offsets, col_nnz, indices, _, _ = flash_moba_gpu.moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal, + ) + return col_offsets, col_nnz, indices + +@_torch_register_fake_wrapper("flash_moba::_moba_fused_topk") +def _moba_fused_topk_fake( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + max_lg_col_num = (max_seqlen_k + moba_chunk_size - 1) // moba_chunk_size + + col_offsets = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int64) + col_nnz = 
torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int32) + indices = torch.empty((total_q * num_heads * moba_topk), device=q.device, dtype=torch.int32) + + return col_offsets, col_nnz, indices + +if torch.__version__ >= "2.4.0": + _wrapped_moba_fused_topk = torch.ops.flash_moba._moba_fused_topk +else: + _wrapped_moba_fused_topk = _moba_fused_topk + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_varlen_sort", mutates_args=(), device_types="cuda") +def _varlen_sort( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) + return flash_moba_gpu.varlen_sort( + col_offsets.view(-1), col_offset_ends, indices + ) + +@_torch_register_fake_wrapper("flash_moba::_varlen_sort") +def _varlen_sort_fake( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + # varlen_sort is out-of-place + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) + return torch.empty_like(indices) + +if torch.__version__ >= "2.4.0": + _wrapped_varlen_sort = torch.ops.flash_moba._varlen_sort +else: + _wrapped_varlen_sort = _varlen_sort + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_moba_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: 
float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + moba_col_offsets = maybe_contiguous(moba_col_offsets) + moba_col_nnz = maybe_contiguous(moba_col_nnz) + moba_row_indices = maybe_contiguous(moba_row_indices) + + out, softmax_lse, S_dmask, rng_state = flash_moba_gpu.moba_varlen_fwd( + q, + k, + v, + None, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + return_softmax, + lg_block_m, + lg_block_n, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, softmax_lse, S_dmask, rng_state + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_forward") +def _flash_moba_attn_varlen_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + out = torch.empty_like(q) + softmax_lse = torch.empty((num_heads, 
total_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + seqlen_q_rounded = round_multiple(max_seqlen_q, 128) + seqlen_k_rounded = round_multiple(max_seqlen_k, 128) + if return_softmax: + p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) + return out, softmax_lse, p, rng_state + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_forward = torch.ops.flash_moba._flash_moba_attn_varlen_forward +else: + _wrapped_flash_moba_attn_varlen_forward = _flash_moba_attn_varlen_forward + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +def _flash_moba_attn_varlen_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_moba_gpu.moba_varlen_bwd( + dout, + q, + k, + v, + out, + 
softmax_lse, + dq, + dk, + dv, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + deterministic, + lg_block_m, + lg_block_n, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return softmax_d + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_backward") +def _flash_moba_attn_varlen_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + + return softmax_d + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_backward = torch.ops.flash_moba._flash_moba_attn_varlen_backward +else: + _wrapped_flash_moba_attn_varlen_backward = _flash_moba_attn_varlen_backward + 
##########################################################################################################################

class FlashMobaAttnVarlenFunc(torch.autograd.Function):
    """Autograd bridge for the varlen MOBA attention custom ops."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        # MOBA sparse pattern parameters
        moba_col_offsets,
        moba_col_nnz,
        moba_row_indices,
        lg_block_m,
        lg_block_n,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, k, v]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(2)
        # The kernel requires head_dim % 8 == 0; pad here and slice back below.
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_moba_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            moba_col_offsets,
            moba_col_nnz,
            moba_row_indices,
            lg_block_m,
            lg_block_n,
            dropout_p,
            softmax_scale,
            causal=causal,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state,
                moba_col_offsets, moba_col_nnz, moba_row_indices
            )
            ctx.dropout_p = dropout_p
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
            ctx.lg_block_m = lg_block_m
            ctx.lg_block_n = lg_block_n

        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, moba_col_offsets, moba_col_nnz, moba_row_indices = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_moba_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            moba_col_offsets,
            moba_col_nnz,
            moba_row_indices,
            ctx.lg_block_m,
            ctx.lg_block_n,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # FIX: forward() takes 20 inputs after ctx, so backward must return exactly
        # 20 gradients: dq/dk/dv for q/k/v plus one None per remaining input.
        # (The previous return carried extra Nones, which autograd rejects.)
        return (dq, dk, dv) + (None,) * 17

##########################################################################################################################

def flash_topk_varlen_func(
    q,
    k,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    # MOBA sparse pattern parameters
    moba_topk,
    moba_chunk_size,
    causal=False,
):
    """
    Computes the top-k indices for Mixture-of-Blocks Attention (MOBA).
    This function handles variable length sequences.

    Args:
        q (torch.Tensor): Query tensor of shape (total_q, num_heads, head_size).
        k (torch.Tensor): Key tensor of shape (total_k, num_heads, head_size).
        cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries, shape (batch_size + 1,).
        cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys, shape (batch_size + 1,).
        max_seqlen_q (int): Maximum sequence length for queries.
        max_seqlen_k (int): Maximum sequence length for keys.
        moba_topk (int): The number of top-k elements to select.
        moba_chunk_size (int): The chunk size for MOBA.
        causal (bool): Whether to apply causal masking.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
            - col_offsets (torch.Tensor): Column offsets for the sparse matrix.
            - col_nnz (torch.Tensor): Number of non-zero elements per column block.
            - indices (torch.Tensor): The top-k indices, sorted within each segment.
    """
    head_size_og = q.size(2)
    if head_size_og % 8 != 0:
        q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
        k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])

    # Mean-pool keys into one representative vector per chunk.
    km, cu_seqlens_km, _ = flash_topk_mean_pool(k, cu_seqlens_k, max_seqlen_k, moba_chunk_size)

    col_offsets, col_nnz, indices = _wrapped_moba_fused_topk(
        q,
        km,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_km,
        max_seqlen_q,
        max_seqlen_k,
        moba_topk,
        moba_chunk_size,
        causal=causal
    )

    indices = _wrapped_varlen_sort(
        col_offsets, col_nnz, indices
    )

    return col_offsets, col_nnz, indices

##########################################################################################################################

def flash_moba_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    # MOBA sparse pattern parameters
    moba_col_offsets,
    moba_col_nnz,
    moba_row_indices,
    lg_block_m=64,
    lg_block_n=64,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        moba_col_offsets: torch.Tensor. Column offsets for MOBA sparse pattern.
            Shape: (batch_size, num_heads, max_lg_col_num), dtype: int64
        moba_col_nnz: torch.Tensor. Non-zero counts per column for MOBA sparse pattern.
            Shape: (batch_size, num_heads, max_lg_col_num), dtype: int32
        moba_row_indices: torch.Tensor. Row indices for MOBA sparse pattern (flattened), dtype: int32
        lg_block_m: int. Logical block size in M dimension (query). Default: 64
        lg_block_n: int. Logical block size in N dimension (key). Default: 64
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention
            score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashMobaAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        moba_col_offsets,
        moba_col_nnz,
        moba_row_indices,
        lg_block_m,
        lg_block_n,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        torch.is_grad_enabled(),
    )

##########################################################################################################################

def flash_moba_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    moba_chunk_size,
    moba_topk,
    causal=True,
):
    """Convenience wrapper: top-k selection followed by sparse MOBA attention."""
    col_offsets, col_nnz, indices = flash_topk_varlen_func(
        q,
        k,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        # MOBA sparse pattern parameters
        moba_topk,
        moba_chunk_size,
        causal=causal,
    )

    lg_block_m = decide_lg_block_m(moba_topk, moba_chunk_size, max_seqlen_k, causal)

    return flash_moba_attn_varlen_func(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        col_offsets,
        col_nnz,
        indices,
        lg_block_m,
        moba_chunk_size,
        dropout_p=0.0,
        causal=causal,
    )

# --- end of flash_moba_interface.py ---
# --- begin vendored file: flashmoba/triton_mean_pool.py ---
# Copyright (c) 2025, FlashMoBA Team.
import torch
import torch.nn.functional as F

import triton
import triton.language as tl


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=[
        # triton.Config({'kBlockN': 16}, num_warps=2, num_stages=3),
        triton.Config({'kBlockN': 32}, num_warps=2, num_stages=3),
        triton.Config({'kBlockN': 32}, num_warps=4, num_stages=3),
        triton.Config({'kBlockN': 32}, num_warps=4, num_stages=4),
        triton.Config({'kBlockN': 64}, num_warps=2, num_stages=3),
        triton.Config({'kBlockN': 64}, num_warps=4, num_stages=3),
        triton.Config({'kBlockN': 64}, num_warps=4, num_stages=4),
        triton.Config({'kBlockN': 64}, num_warps=8, num_stages=3),
        triton.Config({'kBlockN': 128}, num_warps=2, num_stages=3),
        triton.Config({'kBlockN': 128}, num_warps=4, num_stages=3),
        triton.Config({'kBlockN': 128}, num_warps=4, num_stages=4),
        triton.Config({'kBlockN': 128}, num_warps=8, num_stages=3),
        triton.Config({'kBlockN': 128}, num_warps=8, num_stages=4),
        # Larger tiles kept for reference but disabled:
        # triton.Config({'kBlockN': 256}, num_warps=4, num_stages=3),
        # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=3),
        # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=4),
        # triton.Config({'kBlockN': 256}, num_warps=16, num_stages=2),
        # triton.Config({'kBlockN': 512}, num_warps=8, num_stages=2),
        # triton.Config({'kBlockN': 512}, num_warps=16, num_stages=2),
        # triton.Config({'kBlockN': 512}, num_warps=16, num_stages=3),
        # triton.Config({'kBlockN': 1024}, num_warps=16, num_stages=2),
    ],
    key=['HEAD_DIM', 'POOL_BLOCK_SIZE'],
)
@triton.jit
def mean_pool_kernel(
    # Pointers to matrices
    input_ptr,
    output_ptr,
    # Matrix dimensions
    HEAD_DIM: tl.constexpr,
    POOL_BLOCK_SIZE: tl.constexpr,
    cu_seqlens_input,
    cu_seqlens_output,
    input_stride_row, input_stride_head,
    output_stride_row, output_stride_head,
    # Meta-parameters
    kBlockN: tl.constexpr,
):
    """
    Triton kernel for mean pooling over variable-length sequences.

    This kernel computes the mean of non-overlapping blocks of size `POOL_BLOCK_SIZE`
    for each sequence in a batch. It is designed to handle variable sequence lengths.

    Args:
        input_ptr: Pointer to the input tensor of shape (total_seqlen, num_heads, head_dim).
        output_ptr: Pointer to the output tensor of shape (total_blocks, num_heads, head_dim).
        HEAD_DIM: The dimension of each head.
        POOL_BLOCK_SIZE: The size of the pooling window.
        cu_seqlens_input: Cumulative sequence lengths of the input tensor, shape (batch_size + 1,).
        cu_seqlens_output: Cumulative sequence lengths of the output tensor, shape (batch_size + 1,).
        input_stride_row: Stride of the input tensor along the sequence dimension.
        input_stride_head: Stride of the input tensor along the head dimension.
        output_stride_row: Stride of the output tensor along the sequence dimension.
        output_stride_head: Stride of the output tensor along the head dimension.
        kBlockN: Block size for the sequence dimension, a meta-parameter for tuning.
    """
    # Grid: (pooled blocks of the longest sequence, batch, heads).
    n_block = tl.program_id(0)
    bidb = tl.program_id(1)
    bidh = tl.program_id(2)

    seq_start = tl.load(cu_seqlens_input + bidb)
    seq_end = tl.load(cu_seqlens_input + bidb + 1)

    block_start_row = seq_start + n_block * POOL_BLOCK_SIZE

    # Sequences shorter than max_seqlen have fewer blocks: bail out early.
    if seq_end <= block_start_row:
        return

    # The last block of a sequence may be partial.
    actual_block_size = tl.minimum(POOL_BLOCK_SIZE, seq_end - block_start_row)

    offsets_d = tl.arange(0, HEAD_DIM)
    # mask_d = offsets_d < HEAD_DIM

    # Accumulate in fp32 regardless of input dtype for numerical stability.
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

    for block_k_start in range(0, actual_block_size, kBlockN):
        offsets_k = block_k_start + tl.arange(0, kBlockN)
        mask_k = offsets_k < actual_block_size

        row_indices = block_start_row + offsets_k

        # Strides widened to int64 so large row offsets cannot overflow int32.
        input_offset = row_indices[:, None] * input_stride_row.to(tl.int64) + bidh * input_stride_head.to(tl.int64) + offsets_d[None, :]

        inp = tl.load(input_ptr + input_offset, mask=mask_k[:, None], other=0.0)
        acc += tl.sum(inp, axis=0)

    # safe division: actual_block_size >= 1 whenever this point is reached
    mean_val = acc / actual_block_size

    output_start = tl.load(cu_seqlens_output + bidb)
    output_offset = (output_start + n_block) * output_stride_row.to(tl.int64) + bidh * output_stride_head.to(tl.int64) + offsets_d
    tl.store(output_ptr + output_offset, mean_val)


def flash_topk_mean_pool(input, cu_seqlens_input, max_seqlen_input, pool_block_size):
    """
    Performs mean pooling on variable-length sequences using a Triton kernel.

    This function takes a tensor of packed sequences and applies mean pooling over
    fixed-size blocks.

    Args:
        input (torch.Tensor): The input tensor of shape (total_seqlen, num_heads, head_dim).
        cu_seqlens_input (torch.Tensor): Cumulative sequence lengths for the input, shape (batch_size + 1,).
        max_seqlen_input (int): The maximum sequence length in the input batch.
        pool_block_size (int): The size of the pooling window.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, int]: A tuple containing:
            - output (torch.Tensor): The pooled output tensor of shape (total_blocks, num_heads, head_dim).
            - cu_seqlens_output (torch.Tensor): Cumulative sequence lengths for the output.
            - max_seqlen_output (int): The maximum number of blocks for any sequence in the batch.
    """
    total_seqlen, head_num, head_dim = input.shape
    batch_size = cu_seqlens_input.shape[0] - 1

    # ceil-div: number of pooled blocks for the longest sequence
    max_seqlen_output = (max_seqlen_input + pool_block_size - 1) // pool_block_size

    # Per-sequence lengths -> per-sequence block counts -> output cu_seqlens.
    actual_input_seqlens = cu_seqlens_input[1:] - cu_seqlens_input[:-1]
    actual_output_seqlens = (actual_input_seqlens + pool_block_size - 1) // pool_block_size
    cu_seqlens_output = F.pad(torch.cumsum(actual_output_seqlens, dim=0), (1, 0)).to(torch.int32)

    # NOTE(review): .item() forces a device sync to size the output allocation.
    total_blocks = cu_seqlens_output[-1].item()

    output = torch.zeros((total_blocks, head_num, head_dim), dtype=input.dtype, device=input.device)

    # Over-provisioned grid: blocks past a sequence's end exit early in the kernel.
    grid = (max_seqlen_output, batch_size, head_num)

    mean_pool_kernel[grid](
        input,
        output,
        head_dim,
        pool_block_size,
        cu_seqlens_input,
        cu_seqlens_output,
        input.stride(0), input.stride(1),
        output.stride(0), output.stride(1),
    )

    return output, cu_seqlens_output, max_seqlen_output

# --- end of triton_mean_pool.py ---
# --- begin vendored file: fsa/FSA_topk_sparse_attention.py ---
# Copyright 2025 Ran Yan.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from ..nsa.topk_sparse_attention import (backward_sum_o_do, + reorder_topk_idx, + get_num_warps_stages, + is_hopper_gpu) + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_fill_kernel(ptr_tile, ptr_m_i_cur_tiles, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + tl.store(ptr_tile + offsets, -1, mask=mask) # fill int32 with -1 + tl.store(ptr_m_i_cur_tiles + offsets, float("-inf"), mask=mask) + + +def fused_fill(topk_idx_permuted_tile: torch.Tensor, m_i_cur_tiles): + + numel = topk_idx_permuted_tile.numel() + BLOCK_SIZE = 1024 + + # Flatten for pointer access + tile_flat = topk_idx_permuted_tile.view(-1) + + m_i_cur_tiles_flat = m_i_cur_tiles.view(-1) + + grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']),) + + fused_fill_kernel[grid]( + tile_flat, + m_i_cur_tiles_flat, + numel, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + num_stages=3, + ) + + +@triton.jit +def block_to_token_kernel( + topk_idx_ptr, + result_ptr, + N_token, + K, + min_block_id, + max_block_id, + padding_value, + ts_h, + ts_b, + ts_n, + rs_h, + rs_b, + rs_n, + num_q_loops: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) # token index i + pid_h = 0 + offs = tl.arange(0, BLOCK_K) # [0, 1, ..., K-1] + + offs_q = tl.arange(0, num_q_loops) + + pid_j = pid * num_q_loops + offs_q + + topk_idx_offset = pid_h * ts_h + pid_j[None, :] * K + offs[:, None] + block_ids = tl.load( + topk_idx_ptr + topk_idx_offset, mask=(pid_j < N_token)[None, :] & (offs < K)[:, None], other=padding_value + ) + + result_ptrs = result_ptr + pid_h * rs_h + block_ids * N_token + pid_j[None, :] + + mask = (block_ids >= 0) & (block_ids != padding_value) & (pid_j < N_token)[None, :] + tl.store(result_ptrs, pid_j[None, :], mask=mask) + + +def build_block_to_token_triton( + 
result: torch.Tensor, topk_idx: torch.Tensor, min_block_id: int, max_block_id: int, padding_value: int = -1 +): + """ + Args: + topk_idx: [num_heads, N_token, TopK], block indices per token, padded with padding_value for invalid blocks + num_blocks: int + padding_value: int + + Returns: + result: [num_blocks, N_token], token indices per block, padded by padding_value + """ + assert topk_idx.ndim == 3 + assert padding_value == -1 + num_heads, N_token, TopK = topk_idx.shape + + # 每个 token,每个head 一个 program + num_q_loops = 4 + grid = (triton.cdiv(N_token, num_q_loops),) + BLOCK_K = triton.next_power_of_2(TopK) + block_to_token_kernel[grid]( + topk_idx, + result, + N_token, + TopK, + min_block_id, + max_block_id, + padding_value, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + result.stride(0), + result.stride(1), + result.stride(2), + num_q_loops, + BLOCK_K=BLOCK_K, + num_warps=2, + num_stages=3, + ) + return result + + +@triton.jit +def reduce_kernel( + lse_ptr, # float32 [H, N] + m_ij_ptr, # float32 [H, B, N] + l_ij_first_ptr, # float32 [H, 1, N] + l_ij_rest_ptr, # float32 [H, B, N] + m_ij_last_ptr, # float32 [H, N] + o_ptr, # o: n x h x d + o_tiles_first_ptr, # o_tiles: n x h x 1 x d + o_tiles_rest_ptr, # o_tiles: n x h x b x d + acc_o_scales_first_ptr, # acc_o_scales: n x h x 1 + acc_o_scales_rest_ptr, # acc_o_scales: n x h x b + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + start_head_id, + num_qz_loop, + TOPK, + total_len, + # stride + stride_lse_h, + stride_lse_n, + stride_m_ij_h, + stride_m_ij_b, + stride_m_ij_n, + stride_l_ij_fh, + stride_l_ij_fb, + stride_l_ij_fn, + stride_l_ij_rh, + stride_l_ij_rb, + stride_l_ij_rn, + stride_on, + stride_oh, + stride_od, + stride_otfh, + stride_otfb, + stride_otfn, + stride_otfd, + stride_otrh, + stride_otrb, + stride_otrn, + stride_otrd, + stride_acc_fh, + stride_acc_fb, + stride_acc_fn, + stride_acc_rh, + stride_acc_rb, + stride_acc_rn, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, 
+ stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + o_ptrs = o_ptr + pid_q_j * stride_on + off_d + last_acc_o = tl.load(o_ptrs, mask=off_d < BLOCK_SIZE_D, other=0.0) + acc_o = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + acc_o += last_acc_o + + lse_ptrs = lse_ptr + pid_q_j * stride_lse_n + # Load lse + lse = tl.load(lse_ptrs, mask=pid_q_j < total_len, other=float("-inf")) + + # the stride is 1 for m_ij_last + m_ij_last = tl.load(m_ij_last_ptr + pid_q_j) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + real_block_pos = 0 + l_ij_ptr = l_ij_first_ptr + o_tiles_ptr = o_tiles_first_ptr + acc_o_scales_ptr = acc_o_scales_first_ptr + stride_l_ij_b = stride_l_ij_fb + stride_l_ij_n = stride_l_ij_fn + stride_acc_b = stride_acc_fb + stride_acc_n = stride_acc_fn + stride_otb = stride_otfb + stride_otn = stride_otfn + else: + real_block_pos = t - 1 + l_ij_ptr = l_ij_rest_ptr + o_tiles_ptr = o_tiles_rest_ptr + acc_o_scales_ptr = acc_o_scales_rest_ptr + stride_l_ij_b = stride_l_ij_rb + stride_l_ij_n = stride_l_ij_rn + stride_acc_b = stride_acc_rb + stride_acc_n = stride_acc_rn + stride_otb = stride_otrb + stride_otn = stride_otrn + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + m_ij = tl.load( + m_ij_ptr + t * stride_m_ij_b + pid_q_j * stride_m_ij_n, mask=pid_q_j < total_len, other=float("-inf") + ) + l_ij = tl.load( + l_ij_ptr + real_block_pos * stride_l_ij_b + real_token_index * stride_l_ij_n, + mask=real_token_index < total_len, + other=0.0, + ) + delta = lse - m_ij + 
+ log_delta = tl.exp2(delta) + l_ij + + # Update lse + lse = m_ij + tl.log2(log_delta) + + o_tiles_ptrs = ( + o_tiles_ptr + real_block_pos.to(tl.int64) * stride_otb + (real_token_index) * stride_otn + off_d + ) + acc_o_scales_ptrs = acc_o_scales_ptr + real_block_pos * stride_acc_b + (real_token_index) * stride_acc_n + + o_tiles = tl.load(o_tiles_ptrs) + acc_o_scales_tiles = tl.load(acc_o_scales_ptrs) + acc_o = o_tiles + acc_o * acc_o_scales_tiles + + # final scale + acc_o = acc_o * tl.exp2(m_ij_last - lse) + tl.store(o_ptrs, acc_o, mask=off_d < BLOCK_SIZE_D) + + # Store back + tl.store( + lse_ptrs, + lse, + mask=pid_q_j < total_len, + ) + + +@triton.jit +def qk_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + m_i_tiles_ptr, # m_i: h x b x n + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # valid_start_indices: (h x b), + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + num_b_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_m_i_tiles_h, + stride_m_i_tiles_b, + stride_m_i_tiles_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_block_grid = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + + # get q k start and len after rmpad + k_len = tl.load(cu_seqlens_k + 1) + k_ptrs = tl.make_block_ptr( + base=k_ptr + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + for bb in range(num_b_blocks): + pid_block = bb + pid_block_grid * num_b_blocks + + start_id = tl.load(valid_start_indices_ptr + head_id * 
num_blocks + pid_block) + valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block) + if pid_q * BLOCK_SIZE_Q < valid_tokens: + + c = pid_block * BLOCK_SIZE_K + + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_q_blocks): + pid_q_j = pid_q * num_q_blocks + j + # Enable early return + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + q_ptrs = q_ptr + head_id * stride_qh + q_ptrs_off + # load q + q_mask = (st != -1)[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + + m_i = tl.max(qk, axis=1) + + m_i_tiles_ptrs = ( + m_i_tiles_ptr + + head_id * stride_m_i_tiles_h + + pid_block * stride_m_i_tiles_b + + st * stride_m_i_tiles_n + ) + tl.store(m_i_tiles_ptrs, m_i, mask=(st != -1)) + + +@triton.jit +def forward_kernel_opt( + q_ptr, + k_ptr, + v_ptr, # V: n x h x d + o_tiles_ptr, # O: n x h x b x d + acc_o_scales_ptr, # acc_o_scales: h x b x n + m_ij_tiles_ptr, + l_ij_ptr, # h x b x n + token_index_mapping_ptr, + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # 
valid_start_indices: (h x b), + min_block_id, + cur_max_valid_tokens, + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_oth, + stride_otb, + stride_otn, + stride_otd, + stride_acc_oh, + stride_acc_ob, + stride_acc_on, + stride_m_ij_tiles_h, + stride_m_ij_tiles_b, + stride_m_ij_tiles_n, + stride_l_ij_h, + stride_l_ij_b, + stride_l_ij_n, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + # get batch id and head id + pid_block = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + head_id * num_blocks + pid_block) + valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block) + if num_q_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (min_block_id + pid_block) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + head_id * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") + + off_k = tl.arange(0, 
BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_q_blocks): + pid_q_j = pid_q * num_q_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + + # load m_i + mask = st != -1 + + m_ij_tiles_ptrs = ( + m_ij_tiles_ptr + + head_id * stride_m_ij_tiles_h + + (q_start + st) * stride_m_ij_tiles_n + + (pid_block + min_block_id) * stride_m_ij_tiles_b + ) + m_ij = tl.load(m_ij_tiles_ptrs, mask=mask, other=float("-inf")) + + m_ij_tiles_prev_ptrs = ( + m_ij_tiles_ptr + + head_id * stride_m_ij_tiles_h + + (q_start + st) * stride_m_ij_tiles_n + + (pid_block + min_block_id - 1) * stride_m_ij_tiles_b + ) + m_ij_prev = tl.load(m_ij_tiles_prev_ptrs, mask=mask & (pid_block + min_block_id > 0), other=float("-inf")) + + m_i_minus_m_ij = m_ij_prev - m_ij + + q_ptrs = q_ptr + q_start * stride_qn + head_id * stride_qh + q_ptrs_off + # load q + q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + + # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K] + qk_scale = sm_scale * 1.44269504 + qk += tl.dot(q, k) * qk_scale + + # init statistics + acc_o_buffer = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + + # load m_ij and compute l_ij + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + + # load token index mapping + 
token_index_mapping_ptrs = ( + token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + min_block_id) * stride_tim_b + ) + token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1) + + l_ij_ptrs = ( + l_ij_ptr + + head_id * stride_l_ij_h + + (q_start + token_index_mapping) * stride_l_ij_n + + (pid_block) * stride_l_ij_b + ) + tl.store(l_ij_ptrs, l_ij, mask=mask) + # scale acc_o + if pid_block + min_block_id == 0: + acc_o_scale = tl.full((BLOCK_SIZE_Q,), 1.0, dtype=tl.float32) + else: + acc_o_scale = tl.exp2(m_i_minus_m_ij) + + tl.store( + acc_o_scales_ptr + + head_id * stride_acc_oh + + (pid_block) * stride_acc_ob + + (q_start + token_index_mapping) * stride_acc_on, + acc_o_scale, + mask=(st != -1), + ) + + p = p.to(v.dtype) + acc_o_buffer = tl.dot(p, v) + + o_ptrs_off = token_index_mapping[:, None] * stride_otn + off_d[None, :] * stride_otd + o_ptrs = o_tiles_ptr + head_id * stride_oth + o_ptrs_off + (pid_block).to(tl.int64) * stride_otb + tl.store(o_ptrs, acc_o_buffer.to(o_tiles_ptr.dtype.element_ty), mask=q_mask) + + +def _topk_sparse_attention_fwd_opt( + q: torch.Tensor, # [total_len, num_heads, head_dim] + k: torch.Tensor, # [total_len, num_heads, head_dim] + v: torch.Tensor, # [total_len, num_heads, head_dim] + topk_idx: torch.Tensor, # [num_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + causal=True, +): + """ + TODO: Currently sequence packing is explicitly done in for loop, will merge in kernels. 
+ """ + o = torch.empty_like(q) + total_len, num_heads, _ = q.shape + lse = torch.empty((num_heads, total_len), dtype=torch.float32, device=q.device) + + permute_results = [] + for i in range(len(cu_seqlens_q) - 1): + cu_seqlens_q_ = cu_seqlens_q[i: i + 2] - cu_seqlens_q[i] + cu_seqlens_k_ = cu_seqlens_k[i: i + 2] - cu_seqlens_k[i] + max_seqlen_q_ = cu_seqlens_q_[1] - cu_seqlens_q_[0] + max_seqlen_k_ = cu_seqlens_k_[1] - cu_seqlens_k_[0] + + q_ = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_ = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + v_ = v[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + topk_idx_ = topk_idx[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + o_seq, lse_seq, permute_results_seq = _topk_sparse_attention_fwd_opt_per_seq( + q_, + k_, + v_, + topk_idx_, + block_size, + cu_seqlens_q_, + cu_seqlens_k_, + max_seqlen_q_, + max_seqlen_k_, + sm_scale, + causal, + ) + o[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = o_seq + + lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = lse_seq + permute_results.append(permute_results_seq) + + return o, lse, permute_results + + +@triton.jit +def index_mapping_kernel( + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + stride_im_h, + stride_im_b, + stride_im_n, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_q = tl.arange(0, BLOCK_SIZE_K) + offs_n = pid_n * BLOCK_SIZE_K + offs_q + + start_id = tl.load(valid_start_indices_ptr + pid_b) + valid_tokens = tl.load(valid_lens_ptr + pid_b) + + st_offs = start_id + offs_n + # st should be in shape [BLOCK_SIZE_K] + st_mask = offs_n < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + + token_im_ptrs = token_index_mapping_ptr + pid_b * stride_im_b + st * stride_im_n + + tl.store(token_im_ptrs, offs_n, mask=st_mask) + + +def index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks): + max_tokens = valid_lens.max() + BLOCK_SIZE_K = 
1024 + grid = (num_blocks, triton.cdiv(max_tokens, BLOCK_SIZE_K)) + + index_mapping_kernel[grid]( + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_K, + num_warps=2, + num_stages=3, + ) + + +def online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, +): + + # launch kernel + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_q_blocks = 8 + num_b_blocks = 1 + grid_qk = lambda META: ( + triton.cdiv(num_blocks, num_b_blocks), + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_q_blocks), + ) + qk_kernel[grid_qk]( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + head_tile, + num_blocks, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + num_q_blocks, + num_b_blocks, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + m_i_cur_tiles.stride(0), + m_i_cur_tiles.stride(1), + m_i_cur_tiles.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) + + m_ij_tiles = m_i_cur_tiles.cummax(dim=1).values + m_ij_last = m_ij_tiles[:, -1] + + return m_ij_tiles, m_ij_last + + +def qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, +): + BLOCK_SIZE_Q = 128 + 
BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + # a heuristic that avoids large grid size, and redudant KV loading + num_q_blocks = 8 + + grid_fwd = lambda META: ( + compute_tile_size * head_tile, + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_q_blocks), + ) + + forward_kernel_opt[grid_fwd]( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + num_q_blocks, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + v_tile.stride(0), + v_tile.stride(1), + v_tile.stride(2), + o_tiles.stride(0), + o_tiles.stride(1), + o_tiles.stride(2), + o_tiles.stride(3), + acc_o_scales.stride(0), + acc_o_scales.stride(1), + acc_o_scales.stride(2), + m_ij_tiles.stride(0), + m_ij_tiles.stride(1), + m_ij_tiles.stride(2), + l_ij.stride(0), + l_ij.stride(1), + l_ij.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_stages=3, + num_warps=4, + ) + + +def reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, +): + + num_qy_loop = 4 + num_qz_loop = total_len // num_qy_loop + + grid_reduce = lambda META: ( + num_qy_loop + (total_len % num_qy_loop != 0), + num_qz_loop, + ) + + reduce_kernel[grid_reduce]( + lse, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + o, + o_tiles_first, + o_tiles_rest, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + 
token_index_mapping, + h * head_tile, + num_qz_loop, + TOPK, + total_len, + lse.stride(0), + lse.stride(1), + m_ij_tiles.stride(0), + m_ij_tiles.stride(1), + m_ij_tiles.stride(2), + l_ij_first.stride(0), + l_ij_first.stride(1), + l_ij_first.stride(2), + l_ij_rest.stride(0), + l_ij_rest.stride(1), + l_ij_rest.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + o_tiles_first.stride(0), + o_tiles_first.stride(1), + o_tiles_first.stride(2), + o_tiles_first.stride(3), + o_tiles_rest.stride(0), + o_tiles_rest.stride(1), + o_tiles_rest.stride(2), + o_tiles_rest.stride(3), + acc_o_scales_first.stride(0), + acc_o_scales_first.stride(1), + acc_o_scales_first.stride(2), + acc_o_scales_rest.stride(0), + acc_o_scales_rest.stride(1), + acc_o_scales_rest.stride(2), + topk_idx_tile.stride(0), + topk_idx_tile.stride(1), + topk_idx_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_T=triton.next_power_of_2(TOPK), + BLOCK_SIZE_D=triton.next_power_of_2(head_dim), + num_warps=1, + num_stages=2, + ) + + +def _topk_sparse_attention_fwd_opt_per_seq( + q: torch.Tensor, # [total_len, num_heads, head_dim] + k: torch.Tensor, # [total_len, num_kv_heads, head_dim] + v: torch.Tensor, # [total_len, num_kv_heads, head_dim] + topk_idx: torch.Tensor, # [num_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + causal=True, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + + total_len, num_heads, head_dim = q.shape + total_len, num_kv_heads, head_dim = k.shape + + assert num_heads % num_kv_heads == 0 + gqa_deg = num_heads // num_kv_heads + + TOPK = topk_idx.shape[-1] + + real_num_blocks = math.ceil(total_len / block_size) + num_blocks = 
max(real_num_blocks, TOPK) + + head_tile = 1 + reduce_tile_size = num_blocks - 1 + + valid_lens_all = torch.zeros( + ( + num_kv_heads, + num_blocks, + ), + dtype=torch.int32, + device=q.device, + ) + for h in range(num_kv_heads): + topk_idx_tile = topk_idx[h * head_tile: (h + 1) * head_tile] + topk_idx_nonneg = topk_idx_tile[topk_idx_tile >= 0] + valid_lens = torch.bincount(topk_idx_nonneg.view(-1), minlength=num_blocks) + valid_lens_all[h * head_tile: (h + 1) * head_tile] = valid_lens + + global_max_valid_tokens = valid_lens_all[:, 1:].max() if num_blocks > 1 else valid_lens_all.max() + + o_full = torch.zeros_like(q) + lse_full = torch.full((num_heads, total_len), float("-inf"), dtype=torch.float32, device=q.device) + + # New introduced buffers + topk_idx_permuted_tile = torch.full((head_tile, num_blocks, total_len), -1, dtype=torch.int32, device=q.device) + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, dtype=torch.int32, device=q.device) + # first KV block is computed seaprately + o_tiles_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=q.device) + o_tiles_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=q.device + ) + + # Statistics buffers + # m_i_tiles: 历史最大, m_diff_tiles: 历史最大和当前最大的差值 + # m_i_cur_tiles: 当前最大, # m_ij_tiles: 考虑当前和历史后的最大 + m_i_cur_tiles: torch.Tensor = torch.full( + (head_tile, num_blocks, total_len), float("-inf"), dtype=torch.float32, device=q.device + ) + + # first KV block is reduced separately + l_ij_first = torch.full((head_tile, 1, total_len), 0, dtype=torch.float32, device=q.device) + acc_o_scales_first = torch.full((head_tile, 1, total_len), 1, dtype=torch.float32, device=q.device) + + l_ij_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 0, dtype=torch.float32, device=q.device + ) + acc_o_scales_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 1, 
dtype=torch.float32, device=q.device + ) + + permute_results = {} + permute_results['global_max_valid_tokens'] = global_max_valid_tokens + permute_results['num_blocks'] = num_blocks + permute_results['real_num_blocks'] = real_num_blocks + permute_results['valid_topk_idx_permuted_tile'] = [] + permute_results['valid_lens_all'] = valid_lens_all + permute_results['valid_lens'] = [] + permute_results['valid_start_indices'] = [] + + for h in range(num_heads // head_tile): + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + v_tile = v[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + o = o_full[:, h * head_tile: (h + 1) * head_tile] + lse = lse_full[h * head_tile: (h + 1) * head_tile] + + permute_min_block_id = 0 + permute_max_block_id = min(permute_min_block_id + num_blocks, num_blocks) + + topk_idx_tile = topk_idx[(h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + + if h % gqa_deg == 0: + topk_idx_permuted_tile = build_block_to_token_triton( + topk_idx_permuted_tile, topk_idx_tile, permute_min_block_id, permute_max_block_id, padding_value=-1 + ) + + valid_topk_idx_permuted_tile = topk_idx_permuted_tile[topk_idx_permuted_tile != -1] + valid_lens = valid_lens_all[(h // gqa_deg) * head_tile, :] + valid_start_indices = torch.nn.functional.pad(valid_lens.cumsum(0)[:-1], (1, 0), value=0) + + index_mapping( + token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks + ) + + permute_results['valid_topk_idx_permuted_tile'].append(valid_topk_idx_permuted_tile) + permute_results['valid_lens'].append(valid_lens) + permute_results['valid_start_indices'].append(valid_start_indices) + + m_ij_tiles, m_ij_last = online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + 0, + total_len, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + 
cu_seqlens_k, + ) + + m_ij_tiles[:, :, :] = m_ij_tiles[:, :, 0][:, :, None] + m_ij_last[:, :] = m_ij_last[:, 0] + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + o_tiles = o_tiles_first + l_ij = l_ij_first + acc_o_scales = acc_o_scales_first + compute_tile_size = 1 + else: + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + o_tiles = o_tiles_rest + l_ij = l_ij_rest + acc_o_scales = acc_o_scales_rest + compute_tile_size = num_blocks - 1 + + # launch kernel + qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, + ) + + reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, + ) + + o_full[:, h * head_tile: (h + 1) * head_tile] = o + lse_full[h * head_tile: (h + 1) * head_tile] = lse + + if h % gqa_deg == 0: + fused_fill(topk_idx_permuted_tile, m_i_cur_tiles) + + return o_full, lse_full, permute_results + + +@triton.jit +def dq_compute_kernel( + q_ptr, + k_ptr, + v_ptr, + lse_ptr, + delta_ptr, + do_ptr, + dq_tiles_ptr, + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + HEAD_DIM, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug_ptr, + stride_qn, + stride_qh, + stride_qd, + 
stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tim_h, + stride_tim_b, + stride_tim_n, + stride_dqth, + stride_dqtb, + stride_dqtn, + stride_dqtd, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_block = tl.program_id(0) + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + pid_block) + valid_tokens = tl.load(valid_lens_ptr + pid_block) + if num_dq_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (pid_block + compute_min_block_id) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load k + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + + qk_scale = sm_scale * 1.44269504 + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_dq_blocks): + pid_q_j = pid_q * num_dq_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + tl.store(debug_ptr + tl.arange(0, BLOCK_SIZE_Q), st_offs) + # otherwise, st selects a set of q tokens, 
selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + + mask = st != -1 + + q_ptrs = q_ptr + q_start * stride_qn + q_ptrs_off + # load q + q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + do_ptrs = do_ptr + q_start * stride_qn + q_ptrs_off + do = tl.load(do_ptrs, mask=q_mask, other=0) + delta_ptrs = delta_ptr + st[:, None] + d = tl.load(delta_ptrs, mask=mask[:, None], other=0) + lse_ptrs = lse_ptr + st[:, None] + lse = tl.load(lse_ptrs, mask=mask[:, None], other=0) + + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + qk += tl.dot(q, tl.trans(k)) * qk_scale # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + p = tl.exp2(qk - lse) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + dp = tl.dot(do, v) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + ds = sm_scale * p * (dp - d) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + ds = ds.to(q.dtype) + dq = tl.dot(ds, k) # [BLOCK_SIZE_Q, BLOCK_SIZE_D] + + # load token index mapping + token_index_mapping_ptrs = ( + token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + compute_min_block_id) * stride_tim_b + ) + token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1) + + dq_ptrs_off = token_index_mapping[:, None] * stride_dqtn + off_d[None, :] * stride_dqtd + dq_tiles_ptrs = dq_tiles_ptr + dq_ptrs_off + (pid_block).to(tl.int64) * stride_dqtb + tl.store(dq_tiles_ptrs, dq.to(dq_tiles_ptr.dtype.element_ty), mask=q_mask) + + +@triton.jit +def dq_reduce_kernel( + dq_buffer_first_ptr, # [H, 1, N, D] + dq_buffer_rest_ptr, # [H, B, N, D] + dq_ptr, # o: n x h x d + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + num_qz_loop, + TOPK, + total_len, + # stride + stride_dqtfh, + stride_dqtfb, + stride_dqtfn, + stride_dqtfd, + stride_dqtrh, + stride_dqtrb, + stride_dqtrn, + stride_dqtrd, + stride_dqn, + stride_dqh, + 
stride_dqd, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + dq_ptrs = dq_ptr + pid_q_j * stride_dqn + off_d + acc_dq = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + dq_buffer_ptr = dq_buffer_first_ptr + stride_dqtb = stride_dqtfb + stride_dqtn = stride_dqtfn + real_block_pos = 0 + else: + dq_buffer_ptr = dq_buffer_rest_ptr + stride_dqtb = stride_dqtrb + stride_dqtn = stride_dqtrn + real_block_pos = t - 1 + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + dq_buffer_ptrs = ( + dq_buffer_ptr + real_block_pos.to(tl.int64) * stride_dqtb + (real_token_index) * stride_dqtn + off_d + ) + + dq_buffers = tl.load(dq_buffer_ptrs) + acc_dq = dq_buffers + acc_dq + + tl.store(dq_ptrs, acc_dq, mask=off_d < BLOCK_SIZE_D) + + +def backward_dq_opt( + q, # [total_len, num_heads, head_dim] + k, # [total_len, num_k_heads, head_dim] + v, # [total_len, num_k_heads, head_dim] + topk_idx, # [num_k_heads, total_len, topk] + lse, # [num_heads, total_len] + delta, # [num_heads, total_len] + do, # [total_len, num_heads, head_dim] + dq, # [total_len, num_heads, head_dim] + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, +): + """ + TODO: Currently sequence packing is explicitly done in for loop, will merge in kernels. 
+ """ + for i in range(len(cu_seqlens_q) - 1): + cu_seqlens_q_ = cu_seqlens_q[i: i + 2] - cu_seqlens_q[i] + cu_seqlens_k_ = cu_seqlens_k[i: i + 2] - cu_seqlens_k[i] + + permute_results_ = permute_results[i] + + q_ = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_ = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + v_ = v[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + topk_idx_ = topk_idx[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + lse_ = lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + delta_ = delta[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + do_ = do[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + dq_ = dq[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + + backward_dq_opt_per_seq( + q_, + k_, + v_, + topk_idx_, + lse_, + delta_, + do_, + dq_, + cu_seqlens_q_, + cu_seqlens_k_, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results_, + ) + + dq[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = dq_ + + return dq + + +def backward_dq_opt_per_seq( + q, # [total_len, num_k_heads, head_dim] + k, # [total_len, num_k_heads, head_dim] + v, # [total_len, num_k_heads, head_dim] + topk_idx, # [num_k_heads, total_len, topk] + lse, # [num_k_heads, total_len] + delta, # [num_k_heads, total_len] + do, # [total_len, num_k_heads, head_dim] + dq, # [total_len, num_k_heads, head_dim] + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, +): + head_tile = 1 + total_len = topk_idx.shape[1] + global_max_valid_tokens = permute_results['global_max_valid_tokens'] + num_blocks = permute_results['num_blocks'] + reduce_tile_size = num_blocks - 1 + dq_buffer_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=dq.device) + dq_buffer_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=dq.device + ) + + num_heads = num_share_q_heads * num_k_heads + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, 
dtype=torch.int32, device=q.device) + for h in range(num_heads // head_tile): + valid_topk_idx_permuted_tile = permute_results['valid_topk_idx_permuted_tile'][h // num_share_q_heads] + + valid_lens = permute_results['valid_lens'][h // num_share_q_heads] + valid_start_indices = permute_results['valid_start_indices'][h // num_share_q_heads] + + index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks) + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + v_tile = v[:, (h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + do_tile = do[:, h * head_tile: (h + 1) * head_tile] + lse_tile = lse[h * head_tile: (h + 1) * head_tile] + topk_idx_tile = topk_idx[(h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + delta_tile = delta[h * head_tile: (h + 1) * head_tile] + dq_tile = dq[:, h * head_tile: (h + 1) * head_tile] + + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + compute_tile_size = 1 + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + dq_buffer = dq_buffer_first + else: + compute_tile_size = num_blocks - 1 + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + dq_buffer = dq_buffer_rest + + BLOCK_SIZE_Q = 128 + num_dq_blocks = 8 + grid_dq = lambda META: ( + compute_tile_size, + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_dq_blocks), + ) + + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + debug = torch.zeros((BLOCK_SIZE_Q,), dtype=torch.int32, device=dq.device) + 
dq_compute_kernel[grid_dq]( + q_tile, + k_tile, + v_tile, + lse_tile, + delta_tile, + do_tile, + dq_buffer, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + head_dim, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + v_tile.stride(0), + v_tile.stride(1), + v_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + dq_buffer.stride(0), + dq_buffer.stride(1), + dq_buffer.stride(2), + dq_buffer.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + num_qy_loop = 4 + num_qz_loop = total_len // num_qy_loop + + grid_reduce = lambda META: ( + num_qy_loop + (total_len % num_qy_loop != 0), + num_qz_loop, + ) + dq_reduce_kernel[grid_reduce]( + dq_buffer_first, + dq_buffer_rest, + dq_tile, + topk_idx_tile, + token_index_mapping, + num_qz_loop, + topk, + total_len, + dq_buffer_first.stride(0), + dq_buffer_first.stride(1), + dq_buffer_first.stride(2), + dq_buffer_first.stride(3), + dq_buffer_rest.stride(0), + dq_buffer_rest.stride(1), + dq_buffer_rest.stride(2), + dq_buffer_rest.stride(3), + dq_tile.stride(0), + dq_tile.stride(1), + dq_tile.stride(2), + topk_idx_tile.stride(0), + topk_idx_tile.stride(1), + topk_idx_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_T=triton.next_power_of_2(topk), + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=1, + num_stages=2, + ) + + dq[:, h * head_tile: (h + 1) * head_tile] = dq_tile + + return dq + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + 
d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + 
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, 
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_bwd_opt( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + permute_results, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + 
delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = torch.cat( + [ + permute_results[i]['valid_lens_all'][:, : permute_results[i]['real_num_blocks']] + for i in range(len(permute_results)) + ], + dim=1, + ) + + cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + 
topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq_opt( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, + ) + + return dq, dk, dv + + +class FSATopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + permute_results = None + + o, lse, 
permute_results = _topk_sparse_attention_fwd_opt( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.permute_results = permute_results + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + permute_results = ctx.permute_results + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd_opt( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + permute_results, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def FSA_topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return FSATopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def FSA_topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """FSA topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return FSATopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/attention_backend/fsa/__init__.py b/vortex_torch/attention_backend/fsa/__init__.py new file mode 100644 index 0000000..9efd474 --- /dev/null +++ b/vortex_torch/attention_backend/fsa/__init__.py @@ -0,0 +1,9 @@ +from .FSA_topk_sparse_attention import ( + FSA_topk_sparse_attention, + FSA_topk_sparse_attention_varlen, +) + +__all__ = [ + "FSA_topk_sparse_attention", + "FSA_topk_sparse_attention_varlen", +] diff --git a/vortex_torch/attention_backend/nsa/__init__.py b/vortex_torch/attention_backend/nsa/__init__.py new file mode 100644 index 0000000..382da01 --- /dev/null +++ b/vortex_torch/attention_backend/nsa/__init__.py @@ -0,0 +1,9 @@ +from .topk_sparse_attention import ( + topk_sparse_attention, + topk_sparse_attention_varlen, +) + +__all__ = [ + "topk_sparse_attention", + "topk_sparse_attention_varlen", +] diff --git a/vortex_torch/attention_backend/nsa/topk_sparse_attention.py b/vortex_torch/attention_backend/nsa/topk_sparse_attention.py new file mode 100644 index 0000000..57a2be7 --- /dev/null +++ b/vortex_torch/attention_backend/nsa/topk_sparse_attention.py @@ -0,0 +1,1280 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional

import torch
import triton
import triton.language as tl


def is_hopper_gpu():
    """Return True when CUDA device 0 reports compute capability 9.x (Hopper)."""
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability(0)  # minor version unused
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """Heuristically choose a Triton launch configuration.

    Args:
        head_dim: attention head dimension of the kernel being launched.
        block_size: tile size processed per program instance.
        is_hopper_gpu: True when targeting a Hopper (SM90) GPU.

    Returns:
        Tuple ``(num_warps, num_stages)``.
    """
    head_large = head_dim > 64
    block_large = block_size > 64
    if is_hopper_gpu:
        if head_large and block_large:
            num_warps, num_stages = 8, 3
        elif head_large or block_large:
            num_warps, num_stages = 4, 3
        else:
            num_warps, num_stages = 2, 2
    else:
        # NOTE(review): the two "large" branches below pick the same config;
        # kept as-is to preserve the original tuning table.
        if head_large and block_large:
            num_warps, num_stages = 8, 3
        elif head_large or block_large:
            num_warps, num_stages = 8, 3
        else:
            num_warps, num_stages = 2, 2
    return num_warps, num_stages


IS_HOPPER_GPU = is_hopper_gpu()


@triton.jit
def forward_kernel_orig(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    block_size,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    # q loop num
    num_q_loop: tl.constexpr,
    num_k_loop: tl.constexpr,
    MAX_SEQ_LEN: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    # Forward topk sparse attention: each program handles num_k_loop KV heads
    # and num_q_loop query tokens, attending only to the key/value blocks
    # selected in topk_idx. Uses online softmax (flash-attention style) with
    # base-2 exponentials, hence the log2(e) factor folded into qk_scale.
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid = tl.program_id(0)

    Q = MAX_SEQ_LEN // num_q_loop
    HK = NUM_KV_HEADS // num_k_loop

    # which (batch, kv-head chunk, query chunk) this program instance handles
    pid_b = pid // (HK * Q)
    pid_kh_chunk = (pid % (HK * Q)) // Q  # each program processes num_k_loop KV heads
    pid_q = pid % Q

    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start

    if pid_q * num_q_loop >= q_len:
        return
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)

    for kh_offset in range(num_k_loop):
        pid_kh = pid_kh_chunk * num_k_loop + kh_offset
        pid_h = pid_kh * NUM_SHARE_Q_HEADS

        for j in range(real_q_loop):
            pid_q_j = pid_q * num_q_loop + j
            # init topk idx pointer
            off_t = tl.arange(0, BLOCK_SIZE_T)
            t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
            topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)

            # Count only valid (>= 0) block indices that do not reach past the
            # causal frontier (block index <= pid_q_j // block_size); a
            # non-causal variant would drop the second condition.
            real_topk = tl.sum(
                tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0),
                axis=0,
            )
            # init qkv pointer
            q_ptrs = tl.make_block_ptr(
                base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
                shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
                strides=(stride_qh, stride_qd),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
                order=(1, 0),
            )
            k_ptrs = tl.make_block_ptr(
                base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
                shape=(HEAD_DIM, k_len),
                strides=(stride_kd, stride_kn),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
                order=(0, 1),
            )
            v_ptrs = tl.make_block_ptr(
                base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
                shape=(k_len, HEAD_DIM),
                strides=(stride_vn, stride_vd),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
                order=(1, 0),
            )
            # load q
            q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
            # init statistics
            off_h = tl.arange(0, BLOCK_SIZE_H)
            off_k = tl.arange(0, BLOCK_SIZE_K)
            m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
            lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
            acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)
            # sparse attention over the selected key/value blocks
            for i in range(real_topk):
                # get current block start index
                c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
                t_ptr_j = t_ptr_j + stride_tk
                # load k
                k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero")
                # compute qk with causal mask inside the block
                qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
                qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
                # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
                qk += tl.dot(q, k) * qk_scale
                # compute m_ij and l_ij (online softmax update)
                m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
                p = tl.exp2(qk - m_ij[:, None])
                l_ij = tl.sum(p, axis=1)
                # scale acc_o
                acc_o_scale = tl.exp2(m_i - m_ij)
                acc_o = acc_o * acc_o_scale[:, None]
                # load v and update acc_o
                v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero")
                p = p.to(v.dtype)
                acc_o += tl.dot(p, v)
                # update statistics
                m_i = m_ij
                lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)

            # final scale
            acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
            # save output
            o_ptrs = tl.make_block_ptr(
                base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh,
                shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
                strides=(stride_oh, stride_od),
                offsets=(0, 0),
                block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
                order=(1, 0),
            )
            tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
            # save lse (base-2 log-sum-exp, consumed by the backward pass)
            lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
            tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)


@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Compute Delta[h, n] = sum_d O[n, h, d] * dO[n, h, d], the per-token
    # rowsum needed by the attention backward pass.
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    o = tl.load(
        o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od,
        mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod,
        mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len)


@triton.jit
def count_kernel(
    x_ptr,  # [num_kv_heads, total_len, topk]
    y_ptr,  # [num_kv_heads, total_blocks]
    cu_seqlens,  # [batch_size + 1]
    cu_seqblocks,  # [batch_size + 1]
    topk,
    stride_xh,
    stride_xn,
    stride_xk,
    stride_yh,
    stride_yn,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_R: tl.constexpr,
):
    # For one (kv head, sequence) pair, histogram the topk block indices of
    # all queries in the sequence: y[h, b] = number of queries that selected
    # key block b. Padding entries load as -1 and fall outside the counted bins.
    pid_h = tl.program_id(0)
    pid_b = tl.program_id(1)
    # get start and len after rmpad
    seq_start = tl.load(cu_seqlens + pid_b)
    seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start
    blocks_start = tl.load(cu_seqblocks + pid_b)
    num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start
    # load x
    off_k = tl.arange(0, BLOCK_SIZE_K)
    off_n = tl.arange(0, BLOCK_SIZE_N)
    x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn
    x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk
    # init y
    y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32)
    # loop over query tiles, accumulating the histogram
    for i in range(0, seq_len, BLOCK_SIZE_N):
        x = tl.load(
            x_ptrs,
            mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :],
            other=-1,
        )
        x = tl.ravel(x)
        y += tl.histogram(x, BLOCK_SIZE_R)
        x_ptrs += BLOCK_SIZE_N * stride_xn
    # store result (only the first num_blocks bins are meaningful)
    off_r = tl.arange(0, BLOCK_SIZE_R)
    y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn
    y_ptrs = y_ptr + off_r * stride_yn
    tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks)


def count_query(
    topk_idx: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    """Count, per kv head and key block, how many queries selected that block.

    Args:
        topk_idx: [num_kv_heads, total_len, topk] selected block indices,
            -1 for padding.
        cu_seqlens: [batch_size + 1] cumulative sequence lengths.
        cu_seqblocks: [batch_size + 1] cumulative key-block counts.
        block_size: key/value block size (unused here; kept for signature
            parity with callers).

    Returns:
        int32 tensor [num_kv_heads, total_blocks] of active-query counts.
    """
    num_kv_heads, total_len, topk = topk_idx.shape
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1]
    batch_size = seqlens.shape[0]
    BLOCK_SIZE_K = triton.next_power_of_2(topk)
    BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K)
    # +2 leaves headroom so the largest block id maps to a valid histogram bin
    BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2)
    active_query_count = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device)
    grid = (num_kv_heads, batch_size)
    count_kernel[grid](
        topk_idx,
        active_query_count,
        cu_seqlens,
        cu_seqblocks,
        topk,
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        active_query_count.stride(0),
        active_query_count.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_R=BLOCK_SIZE_R,
        num_warps=4,
        num_stages=3,
    )
    return active_query_count


@triton.jit
def pad_topk_idx_kernel(
    t_ptr,
    p_ptr,
    cu_seqlens,
    topk,
    stride_th,
    stride_tn,
    stride_tk,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_pk,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    # Scatter the packed (rmpad) topk indices [kh, total_len, topk] into the
    # padded layout [batch, kh, max_seqlen, topk]; untouched slots keep the
    # fill value written by the caller.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # get q start and len after rmpad
    q_start = tl.load(cu_seqlens + pid_b)
    q_len = tl.load(cu_seqlens + pid_b + 1) - q_start
    if BLOCK_SIZE_N * pid_n >= q_len:
        return
    # init ptrs
    t_ptrs = tl.make_block_ptr(
        base=t_ptr + pid_h * stride_th + q_start * stride_tn,
        shape=(q_len, topk),
        strides=(stride_tn, stride_tk),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
        order=(1, 0),
    )
    p_ptrs = tl.make_block_ptr(
        base=p_ptr + pid_b * stride_pb + pid_h * stride_ph,
        shape=(q_len, topk),
        strides=(stride_pn, stride_pk),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
        order=(1, 0),
    )
    # load and save
    idxs = tl.load(t_ptrs, boundary_check=(0, 1))
    tl.store(p_ptrs, idxs, boundary_check=(0, 1))


@triton.jit
def save_topk_idx_kernel(
    p_ptr,
    t_ptr,
    cu_seqblocks,
    cu_topk_q_count,
    n_len,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_th,
    stride_tn,
    stride_ch,
    stride_cn,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Copy the valid (non-padding) tail of the padded, sorted query indices
    # back into the packed [kh, total_active] layout. Valid entries sit at the
    # end of the padded row because padding sorts first; hence the
    # (n_len - c_len) offset.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # get q start and len after rmpad
    q_block_start = tl.load(cu_seqblocks + pid_b)
    q_block_end = tl.load(cu_seqblocks + pid_b + 1)
    c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn)
    c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn)
    c_len = c_end - c_start
    if c_len <= 0:
        return
    if pid_n * BLOCK_SIZE_N >= c_len:
        return
    # init ptrs
    p_ptrs = tl.make_block_ptr(
        base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn,
        shape=(c_len,),
        strides=(stride_pn,),
        offsets=(pid_n * BLOCK_SIZE_N,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    t_ptrs = tl.make_block_ptr(
        base=t_ptr + pid_h * stride_th + c_start * stride_tn,
        shape=(c_len,),
        strides=(stride_tn,),
        offsets=(pid_n * BLOCK_SIZE_N,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    # load and save
    idxs = tl.load(p_ptrs, boundary_check=(0,))
    tl.store(t_ptrs, idxs, boundary_check=(0,))


+def reorder_topk_idx( + topk_idx: torch.Tensor, + cu_topk_q_count: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + batch_size = cu_seqlens.shape[0] - 1 + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] + pad_topk_idx = torch.full( + (batch_size, num_kv_heads, max_seqlen, topk), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T)) + grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) + pad_topk_idx_kernel[grid]( + topk_idx, + pad_topk_idx, + cu_seqlens, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + pad_topk_idx.stride(0), + pad_topk_idx.stride(1), + pad_topk_idx.stride(2), + pad_topk_idx.stride(3), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_T=BLOCK_SIZE_T, + ) + # argsort + pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk + pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) + # save as remove pad version + topk_q_idx = torch.full( + (num_kv_heads, cu_topk_q_count[:, -1].max().item()), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + max_len = (cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1]).max().item() + BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) + grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) + save_topk_idx_kernel[grid]( + pad_topk_q_idx, + topk_q_idx, + cu_seqblocks, + cu_topk_q_count, + pad_topk_q_idx.shape[-1], + pad_topk_q_idx.stride(0), + pad_topk_q_idx.stride(1), + pad_topk_q_idx.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + 
BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return topk_q_idx + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # 
init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, 
other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id 
and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_q = tl.program_id(2) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * num_q_loop >= q_len: + return + real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0), + axis=0, + ) + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_dqh, stride_dqd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + 
strides=(stride_doh, stride_dod), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_dh, stride_dn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_lh, stride_ln), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + # offsets + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) + # sparse + for i in range(real_topk): + # get current block start index + c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K + t_ptr_j = t_ptr_j + stride_tk + # load + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_fwd( + q: torch.Tensor, # [total_len, num_q_heads, 
head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + topk = topk_idx.shape[-1] + assert topk_idx.shape[0] == num_k_heads + assert topk_idx.shape[1] == q_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + + lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) + + # launch kernel + num_q_loop = num_k_loop = 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + + def grid(meta): + grid = ( + batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop), + ) + return grid + + num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU) + forward_kernel_orig[grid]( + q, + k, + v, + topk_idx, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + block_size, + # num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + 
o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + num_q_loop=num_q_loop, + num_k_loop=num_k_loop, + MAX_SEQ_LEN=max_seqlen_q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _topk_sparse_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) + 
+ cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = 
(batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq[grid]( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class TopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _topk_sparse_attention_fwd( + q, + k, + v, + topk_idx, + block_size, + 
cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Same as topk_sparse_attention but accepts separate cu_seqlens for Q and K. + Useful when Q only covers new tokens while K covers all tokens (prefix + new). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index eddfa46..6b54905 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,11 +29,14 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_pages_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", + "dequant_pages_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index ae2dd5c..3cdf095 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -10,17 +10,25 @@ class Context(ContextBase): """ __slots__ = ContextBase.__slots__ + ( - + #page infomation "max_new_tokens_per_batch", "page_size", "total_num_pages", - + #model infomation "head_dim", "head_num", - + # auxilary memory in graph "_aux_total_bytes", - - "_aux_total_flops" + + "_aux_total_flops", + + # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), + # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + # fp8_type: 0=none, 1=e4m3, 2=e5m2 (encoding for Triton kernels) + "quant_type", + "kv_scale", + "kv_scale_ptr", + "fp8_type", ) @@ -36,7 +44,15 @@ def __init__(self) -> None: elif name == 
"_aux_total_flops": object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": - object.__setattr__(self, name, Mode.profile) + object.__setattr__(self, name, Mode.profile) + elif name == "quant_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) + elif name == "kv_scale": + object.__setattr__(self, name, 1.0) # identity scale for bf16 + elif name == "kv_scale_ptr": + object.__setattr__(self, name, None) # per-token scale tensor (int8 only) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 3c4edf2..eb94795 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,8 +345,11 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr) + quant_type = getattr(ctx, 'quant_type', 0) + scale = getattr(ctx, 'kv_scale', 1.0) + kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr) return output diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 6bf6dfc..de4fcbd 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,4 +1,17 @@ -from .set_kv import set_kv_buffer_launcher - -__all__ = ["set_kv_buffer_launcher"] +from .set_kv import ( + set_kv_buffer_launcher, + set_kv_buffer_int8_launcher, + set_kv_buffer_fp8_launcher, + paged_decode, + dequant_pages_to_bf16, + dequant_pages_to_bf16_inplace, +) +__all__ = [ + "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", + 
"paged_decode", + "dequant_pages_to_bf16", + "dequant_pages_to_bf16_inplace", +] diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9921e08..0146af7 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -4,6 +4,17 @@ from ..context import Context from ...utils import ReduceType + +# --------------------------------------------------------------------------- +# Helper: Load a page block from src_ptr, handling bf16 / int8 / fp8-stored-as-uint8. +# QUANT_TYPE == 0 -> bf16 pointer, load normally +# QUANT_TYPE == 1 -> int8 pointer, dequant with per-row scale from kv_scale_ptr +# QUANT_TYPE == 2 -> uint8 pointer, bitcast to float8e4nv, dequant with per-tensor scale +# QUANT_TYPE == 3 -> uint8 pointer, bitcast to float8e5, dequant with per-tensor scale +# All quantised paths return a float32 tensor ready for reduction. +# --------------------------------------------------------------------------- + + @triton.jit def reduce_pp_kernel( x, output, loc, @@ -12,9 +23,12 @@ def reduce_pp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm -DIM: tl.constexpr # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -): - +DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) +): + token_id = tl.program_id(0) head_id = tl.program_id(1) @@ -29,7 +43,22 @@ def reduce_pp_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - page_block = tl.load(src_ptr) + + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = 
tl.load(src_ptr).to(tl.float32) + # Per-row scales stored at kv_scale_ptr[page_id * x_D0 + row] + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) # [x_D0] + page_block = raw * row_scales[:, None] # broadcast [x_D0, 1] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -71,11 +100,14 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -85,7 +117,10 @@ def reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -97,11 +132,14 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -111,7 +149,10 @@ def _reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -124,9 +165,12 @@ def reduce_rp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) - DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + QUANT_TYPE: 
tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + scale, # float: 1.0 for bf16, kv_scale for fp8 + kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): - + # Program IDs: # pid0 = token index (0 .. num_tokens-1) # pid1 = head index (0 .. NUM_KV_HEAD-1) @@ -156,7 +200,20 @@ def reduce_rp_kernel( # Load the full page block for this (token_id, head_id). # Assumes the page is full; add masks here if you have partial tiles. - page_block = tl.load(src_ptr) + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) # Reduction: if DIM == 1: @@ -196,7 +253,7 @@ def reduce_rp_kernel( # Write to output: layout [num_pages, x_D0] for DIM==2. 
dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) - + def reduce_rp( @@ -206,11 +263,14 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -220,7 +280,10 @@ def reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -232,11 +295,14 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -246,7 +312,10 @@ def _reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -258,7 +327,10 @@ def reduce_pr_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): """ Layouts: @@ -297,7 +369,20 @@ def reduce_pr_kernel( src_ptr = x + x_offset + rows * x_D1 + cols # Load the full page block. Assumes full tiles; add masks if needed. 
- page_block = tl.load(src_ptr) + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr) # --- Reduction & write-out --- if DIM == 1: @@ -344,11 +429,14 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -358,9 +446,12 @@ def reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) - + def _reduce_pr( x: torch.Tensor, output: torch.Tensor, @@ -369,11 +460,14 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -383,7 +477,10 @@ def _reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -395,7 +492,10 @@ def reduce_rr_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +DIM: tl.constexpr, # 1: 
reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): """ Layouts: @@ -420,7 +520,22 @@ def reduce_rr_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_base + rows * x_D1 + cols - page_blk = tl.load(src_ptr) # assumes full page; add masks if needed + + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_blk = raw * row_scales[:, None] + elif QUANT_TYPE == 2: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif QUANT_TYPE == 3: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_blk = tl.load(src_ptr) # assumes full page; add masks if needed # ---- reduce ---- if DIM == 1: @@ -464,11 +579,14 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -478,9 +596,12 @@ def reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) - + def _reduce_rr( x: torch.Tensor, @@ -490,11 +611,14 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, +quant_type: int = 0, +scale: float = 1.0, +kv_scale_ptr=None, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rr_kernel[(NNZ, 
NUM_KV_HEAD)]( x=x, output=output, @@ -504,5 +628,8 @@ def _reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim - ) \ No newline at end of file + DIM=dim, + QUANT_TYPE=quant_type, + scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index cfa3cab..58468cc 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -36,6 +36,93 @@ def set_kv_buffer_kernel( tl.store(dst_v_ptr, src_v) +@triton.jit +def set_kv_buffer_int8_kernel( + k_cache, # int8 paged K cache + v_cache, # int8 paged V cache + k_scale_cache, # fp16 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # fp16 per-token V scale [num_pages, page_size, 1] + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr +): + """Quantize bf16 K/V to int8 with per-token absmax scaling and write to paged buffers.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Compute per-token absmax scale: scale = absmax / 127 + absmax_k = tl.max(tl.abs(src_k), axis=0) + absmax_v = tl.max(tl.abs(src_v), axis=0) + # Avoid division by zero + scale_k = absmax_k / 127.0 + 1e-10 + scale_v = absmax_v / 127.0 + 1e-10 + + # Quantize to int8: round(x / scale), clamp to [-128, 127] + q_k = tl.extra.cuda.libdevice.rint(src_k / scale_k) + q_k = tl.minimum(tl.maximum(q_k, -128.0), 127.0).to(tl.int8) + q_v = tl.extra.cuda.libdevice.rint(src_v / scale_v) + 
q_v = tl.minimum(tl.maximum(q_v, -128.0), 127.0).to(tl.int8) + + # Compute paged destination offset (same layout as bf16 kernel) + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write int8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + # Write per-token scales (fp16): shape [num_pages, page_size, 1] + # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) + scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset + tl.store(k_scale_cache + scale_offset, scale_k.to(tl.float16)) + tl.store(v_scale_cache + scale_offset, scale_v.to(tl.float16)) + + +def set_kv_buffer_int8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int +): + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_int8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + new_k, + new_v, + loc, + NUM_KV_HEAD, + NNZ, + HEAD_DIM, + page_size + ) + + def set_kv_buffer_launcher( k_cache: torch.Tensor, v_cache: torch.Tensor, @@ -44,11 +131,11 @@ def set_kv_buffer_launcher( loc: torch.LongTensor, page_size: int ): - + NNZ = loc.shape[0] NUM_KV_HEAD = new_k.shape[1] HEAD_DIM = new_k.shape[2] - + set_kv_buffer_kernel[(NNZ, NUM_KV_HEAD)]( k_cache, v_cache, @@ -61,3 +148,591 @@ def set_kv_buffer_launcher( page_size ) + +@triton.jit +def set_kv_buffer_fp8_kernel( + k_cache, # uint8 paged K cache + v_cache, # uint8 paged V cache + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, 
HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + FP8_TYPE: tl.constexpr, # 1: e4m3 (max=448), 2: e5m2 (max=57344) + k_scale, # float: per-tensor scale for K quantization + v_scale, # float: per-tensor scale for V quantization +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Scale down: quantized = real_value / scale + inv_k_scale = 1.0 / k_scale + inv_v_scale = 1.0 / v_scale + scaled_k = src_k * inv_k_scale + scaled_v = src_v * inv_v_scale + + # Clamp and cast to fp8, then bitcast to uint8 for storage + if FP8_TYPE == 1: + # e4m3: max = 448.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -448.0), 448.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -448.0), 448.0) + q_k = clamped_k.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + else: + # e5m2: max = 57344.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -57344.0), 57344.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -57344.0), 57344.0) + q_k = clamped_k.to(tl.float8e5).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e5).to(tl.uint8, bitcast=True) + + # Compute paged destination offset + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write uint8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, 
q_v) + + +def set_kv_buffer_fp8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int, + k_scale: float, + v_scale: float, + fp8_type: int = 1, +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache. + + Args: + fp8_type: 1 for e4m3 (default), 2 for e5m2. + k_scale: per-tensor scale used for K quantization. + v_scale: per-tensor scale used for V quantization. + """ + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_fp8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, v_cache, + new_k, new_v, + loc, + NUM_KV_HEAD, NNZ, HEAD_DIM, page_size, + FP8_TYPE=fp8_type, + k_scale=k_scale, + v_scale=v_scale, + ) + + +# --------------------------------------------------------------------------- +# Dequantization kernels (read direction: quantized paged cache → bf16) +# --------------------------------------------------------------------------- + +@triton.jit +def _dequant_pages_kernel( + src, # quantized paged buffer flat + src_scale, # per-token scale buffer flat (int8 only) + dst, # bf16 destination buffer flat + page_indices, # int32 page indices to dequant + NUM_PAGES, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # float: per-tensor scale (fp8 only) + COMPACT: tl.constexpr, # True: compact dst; False: in-place dst +): + """Unified dequant kernel for selected pages → bf16. + + QUANT_TYPE==1: load int8, multiply by per-token scale from src_scale. + QUANT_TYPE==2: load uint8, bitcast to float8e4nv, multiply by tensor_scale. + QUANT_TYPE==3: load uint8, bitcast to float8e5, multiply by tensor_scale. + COMPACT==True: dst offset uses page_idx (compact buffer). + COMPACT==False: dst offset uses global_page_id (in-place). 
+ """ + page_idx = tl.program_id(0) + token_idx = tl.program_id(1) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + scale_offset = global_page_id * PAGE_SIZE + token_idx + + if QUANT_TYPE == 1: + val = tl.load(src + src_offset, mask=mask_dim, other=0).to(tl.float32) + scale = tl.load(src_scale + scale_offset).to(tl.float32) + result = (val * scale).to(tl.bfloat16) + elif QUANT_TYPE == 2: + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + else: # QUANT_TYPE == 3 + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e5, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + + if COMPACT: + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + else: + dst_offset = src_offset # same position as source + + tl.store(dst + dst_offset, result, mask=mask_dim) + + +def dequant_pages_to_bf16( + src: torch.Tensor, + src_scale: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, + out: torch.Tensor = None, +) -> torch.Tensor: + """Dequant selected pages to compact bf16 buffer. + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). + out: optional pre-allocated bf16 buffer. 
+ """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + if out is not None: + return out[:0] + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src.device) + + if out is not None: + dst = out[:num_accessed_pages] + else: + dst = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=True, + ) + + return dst + + +def dequant_pages_to_bf16_inplace( + src: torch.Tensor, + src_scale: torch.Tensor, + dst: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, +) -> None: + """Dequant selected pages in-place (same page positions in dst). + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). 
+ """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=False, + ) + + +# --------------------------------------------------------------------------- +# Paged decode attention (unified quant_type-parameterized) +# --------------------------------------------------------------------------- + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def _tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_paged_decode_stage1( + Q, + K_Buffer, + V_Buffer, + K_Scale_Buffer, + V_Scale_Buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_vbs, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # per-tensor scale for fp8 +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + 
kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, other=0, + ) + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load K with quant-type-dependent dequantization + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + if QUANT_TYPE == 0: + k = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + k_scale = tl.load( + K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + k = k_int8 * k_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * _tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load V with quant-type-dependent dequantization + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + if 
QUANT_TYPE == 0: + v = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + v_scale = tl.load( + V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + v = v_int8 * v_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store(Att_Out + offs_mid_o, acc / e_sum, mask=mask_dv) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store(Att_Lse + offs_mid_o_1, e_max + tl.log(e_sum)) + + +@triton.jit +def _fwd_kernel_paged_decode_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = 
tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode( + q: torch.Tensor, + k_buffer: torch.Tensor, + v_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + last_page_len: torch.Tensor, + num_kv_splits: torch.Tensor, + max_kv_splits: int, + sm_scale: float, + page_size: int, + quant_type: int = 0, + k_scale_buffer: torch.Tensor = None, + v_scale_buffer: torch.Tensor = None, + tensor_scale: float = 1.0, + logit_cap: float = 0.0, + att_out: torch.Tensor = None, + att_lse: torch.Tensor = None, +): + """Unified paged decode attention. 
+ + Args: + quant_type: Controls K/V loading: + 0: bf16 (k_scale_buffer/v_scale_buffer unused) + 1: int8 with per-token scales (k_scale_buffer/v_scale_buffer required) + 2: fp8 e4m3 with per-tensor scale (tensor_scale required) + 3: fp8 e5m2 with per-tensor scale (tensor_scale required) + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 128 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + num_warps = 4 + + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, device=q.device, + ) + else: + att_lse = att_lse[:batch] + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + # Use dummy tensors for scale buffers when not needed + _k_scale = k_scale_buffer if k_scale_buffer is not None else k_buffer + _v_scale = v_scale_buffer if v_scale_buffer is not None else v_buffer + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_paged_decode_stage1[grid_stage1]( + q, k_buffer, v_buffer, + _k_scale, _v_scale, + sm_scale, kv_indptr, kv_indices, last_page_len, + att_out, att_lse, num_kv_splits, + q.stride(0), q.stride(1), + stride_buf_kbs, stride_buf_vbs, + att_out.stride(0), att_out.stride(1), att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, Lv=Lv, + PAGE_SIZE=page_size, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_paged_decode_stage2[grid_stage2]( + att_out, att_lse, o, + kv_indptr, last_page_len, num_kv_splits, + att_out.stride(0), att_out.stride(1), 
att_out.stride(2), + o.stride(0), o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) + diff --git a/vortex_torch/flow/__init__.py b/vortex_torch/flow/__init__.py index b2fcadc..bb60b89 100644 --- a/vortex_torch/flow/__init__.py +++ b/vortex_torch/flow/__init__.py @@ -34,9 +34,11 @@ class BlockSparseAttention(vFlow): from .registry import register from .loader import build_vflow from . import algorithms +from . import external_algorithms __all__ = [ "vFlow", "register", "build_vflow", - "algorithms" + "algorithms", + "external_algorithms", ] \ No newline at end of file diff --git a/vortex_torch/flow/external_algorithms.py b/vortex_torch/flow/external_algorithms.py new file mode 100644 index 0000000..5f8935f --- /dev/null +++ b/vortex_torch/flow/external_algorithms.py @@ -0,0 +1,76 @@ +""" +External sparse attention algorithm registrations for NSA, FSA, and FlashMoBA. + +These vFlow subclasses use simple centroid-based routing for the DECODE path +(forward_indexer + forward_cache), identical to BlockSparseAttention. + +The EXTEND path (forward_extend) is handled directly in vtx_graph_backend.py +using each algorithm's own sparse attention kernel — these vFlow classes are +not involved in extend. +""" + +import torch +from typing import Dict, Tuple + +from .flow import vFlow +from ..indexer import topK, GeMV +from ..cache import Mean as CMean +from ..abs import ContextBase +from .registry import register + + +class _ExternalAlgoBase(vFlow): + """ + Base vFlow for external sparse attention algorithms (NSA, FSA, FlashMoBA). + + Decode routing: centroid-based (same as BlockSparseAttention). + Extend: bypassed — vtx_graph_backend dispatches to algorithm-specific kernels. 
+ """ + + def __init__(self): + super().__init__() + self.gemv = GeMV() + self.output_func = topK() + self.reduction = CMean(dim=1) + + def forward_indexer( + self, + q: torch.Tensor, + o: torch.Tensor, + cache: Dict[str, torch.Tensor], + ctx: ContextBase, + ): + q_mean = q.mean(dim=1, keepdim=True) + score = self.gemv(q_mean, cache["centroids"], ctx=ctx) + self.output_func(score, o, ctx=ctx) + + def forward_cache( + self, + cache: Dict[str, torch.Tensor], + loc: torch.Tensor, + ctx: ContextBase, + ): + self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx) + + def create_cache(self, page_size: int, head_dim: int) -> Dict[str, Tuple[int, int]]: + return { + "centroids": (1, head_dim), + } + + +@register("nsa") +class NSASparseAttention(_ExternalAlgoBase): + """Naive Sparse Attention — decode uses centroid routing, extend uses NSA kernels.""" + pass + + +@register("fsa") +class FSASparseAttention(_ExternalAlgoBase): + """Flash Sparse Attention — decode uses centroid routing, extend uses FSA kernels.""" + pass + + +@register("flash_moba") +class FlashMoBASparseAttention(_ExternalAlgoBase): + """FlashMoBA — decode uses centroid routing, extend uses FlashMoBA kernels.""" + pass diff --git a/vortex_torch/flow/flow.py b/vortex_torch/flow/flow.py index 7efc80e..7da5c72 100644 --- a/vortex_torch/flow/flow.py +++ b/vortex_torch/flow/flow.py @@ -431,6 +431,7 @@ def run_indexer_virtual(self, group_size: int, page_size: int, head_dim: int): ctx.page_size = page_size ctx.max_num_pages = 0 ctx.max_num_pages_per_request = 0 + ctx.topk_type = "naive" device = "cuda" dtype = torch.bfloat16 diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 6d3c586..17dea66 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -22,7 +22,9 @@ class Context(ContextBase): # hardware / paging "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc - "indexer_dtype", "topk_val", "page_reserved_bos", 
"page_reserved_eos", + "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", + "topk_mapping_mode", "topk_mapping_power", + "topk_histogram_enabled", # auxilary memory in graph "_aux_total_bytes", @@ -68,6 +70,10 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). + topk_type: str #: TopK kernel type: "naive", "sglang" (unmapped) or "sglang_fused" (remap+topk). + topk_mapping_mode: int #: TopK mapping mode for sglang_fused (0=none, 3=power, 4=log, 6=asinh, 7=log1p, 9=erf, 10=tanh, 13=exp_stretch). + topk_mapping_power: float #: Hyperparameter (p / alpha / beta) for the active mapping mode. + topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -144,12 +150,16 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos + self.topk_type = getattr(sa, "vortex_topk_type", "naive") + self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) + self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) + self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) + + device = getattr(model_runner, "device", "cpu") self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads ) - - device = getattr(model_runner, "device", "cpu") self.winfo_q_indices = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_offsets = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_lens = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) diff --git 
a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 5df795b..889e068 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,10 +1,21 @@ import torch -from typing import Dict, Callable, Optional +from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output +from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_fused, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT +# --- Module-level histogram accumulator for offline calibration --- +_calibration_histograms: List[torch.Tensor] = [] + +def get_calibration_histograms() -> List[torch.Tensor]: + """Return collected histogram tensors (each [eff_bs, 256] int32 on CPU).""" + return _calibration_histograms + +def clear_calibration_histograms() -> None: + """Clear all collected calibration histograms.""" + _calibration_histograms.clear() + class topK(vOp): r""" Piecewise top-k dispatcher for packed sequences with reserved pages. @@ -75,13 +86,19 @@ class topK(vOp): """ # Dispatch by input format; only RAGGED is supported for now. - _impl_map: Dict[FORMAT, Callable] = { - FORMAT.RAGGED: topk_output, + _impl_map: Dict[FORMAT, Dict[str, Callable]] = { + FORMAT.RAGGED: { + "naive": topk_output, + "sglang": topk_output_sglang, + "sglang_fused": topk_output_sglang_fused, + }, } def __init__(self): super().__init__() self.impl: Optional[Callable] = None + self.topk_type: str = "naive" + self.last_histograms: Optional[torch.Tensor] = None # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -152,7 +169,13 @@ def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: f"{prefix}no implementation for x._format={x_fmt}. 
" f"Available: {list(self._impl_map.keys())}" ) - self.impl = self._impl_map[x_fmt] + self.topk_type = getattr(ctx, "topk_type", "naive") + impl_variants = self._impl_map[x_fmt] + assert self.topk_type in impl_variants, ( + f"{prefix}no topk implementation for topk_type='{self.topk_type}'. " + f"Available: {list(impl_variants.keys())}" + ) + self.impl = impl_variants[self.topk_type] # ---- optional sanity checks on `o` ---- # We only assert device consistency and leave exact (S_pack, D0, D1) @@ -220,16 +243,85 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso prefix = self._prefix() assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" - self.impl( - x, - ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, - ctx.dense_kv_indices, - o, - ctx.batch_size * ctx.num_kv_heads, - ctx.topk_val, - ctx.page_reserved_bos, - ctx.page_reserved_eos, - ctx.max_num_pages_per_request, - ) + if self.topk_type == "sglang": + # topk_output_sglang: unmapped baseline (no remap). + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + elif self.topk_type == "sglang_fused": + # topk_output_sglang_fused: single-launch fused remap + topk. + mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) + mapping_power = getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + ) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + int(mapping_mode), + float(mapping_power), + ) + else: + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) 
+ self.impl( + x, + ctx.dense_kv_indptr, + ctx.dense_kv_indices, + ctx.sparse_kv_indptr, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + + # Optional histogram profiling (default disabled, no overhead when off). + # Skip entirely during CUDA graph capture — allocations and D2H copies + # are not permitted while a stream is being captured. + if ( + getattr(ctx, 'topk_histogram_enabled', False) + and self.topk_type in ("sglang", "sglang_fused") + and not torch.cuda.is_current_stream_capturing() + ): + eff_bs = ctx.batch_size * ctx.num_kv_heads + self.last_histograms = torch.zeros(eff_bs, 256, dtype=torch.int32, device=x.device) + hist_mode = 0 + hist_power = 0.5 + if self.topk_type == "sglang_fused": + hist_mode = int(getattr(ctx, 'topk_mapping_mode', 0)) + hist_power = float(getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + )) + topk_profile_histogram( + x, + ctx.dense_kv_indptr, + self.last_histograms, + eff_bs, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + hist_mode, + hist_power, + ) + # Accumulate histograms for offline calibration + _calibration_histograms.append(self.last_histograms.cpu().clone()) + return o diff --git a/vortex_torch/indexer/utils_sglang.py b/vortex_torch/indexer/utils_sglang.py index 74b8cfe..343207f 100644 --- a/vortex_torch/indexer/utils_sglang.py +++ b/vortex_torch/indexer/utils_sglang.py @@ -40,7 +40,7 @@ def plan_decode( ctx.max_chunk_size, ctx.min_chunk_size ) - + ctx.set_batch_size(cached_seq_lens.shape[0]) diff --git a/vortex_torch/kernels/__init__.py b/vortex_torch/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vortex_torch/kernels/fsa/__init__.py b/vortex_torch/kernels/fsa/__init__.py new file mode 100644 index 0000000..25d5b3e --- /dev/null +++ b/vortex_torch/kernels/fsa/__init__.py @@ -0,0 +1,5 @@ +from .fused_score_kernels import 
_fused_attention_score_and_transform + +__all__ = [ + "_fused_attention_score_and_transform", +] diff --git a/vortex_torch/kernels/fsa/fused_score_kernels.py b/vortex_torch/kernels/fsa/fused_score_kernels.py new file mode 100644 index 0000000..f2a05ed --- /dev/null +++ b/vortex_torch/kernels/fsa/fused_score_kernels.py @@ -0,0 +1,300 @@ +# This file provides a fused implementation of computing attention score for selected attention indices. +# TODO: this implementation may incur illegal memory access issues, will be fixed. +import math + +import torch +import triton +import triton.language as tl + +from ..nsa.utils import is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_score_kernel( + q_ptr, # q_len x h x d + k_ptr, # k_len x h x d + lse_ptr, # h x n + bs_ptr, # h x n x nb + offs_ptr, # BO + kernel_size, + kernel_stride, + num_offs, # BO + num_k_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, # which is also num_q_heads + HEAD_DIM, + # sm_scale + sm_scale, + max_blocks, + pad_len, + block_size, + block_stride, + init_blocks, + local_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_bsh, + stride_bsq, + stride_bsnb, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) # the blocks id of k + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + + k_start += pid_k * BLOCK_SIZE_K * num_k_blocks + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= 
k_len: + return + + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_kh * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + + for j in range(num_k_blocks): + k_start_j = k_start + j * BLOCK_SIZE_K + if k_start_j < k_len: + off_d = tl.arange(0, BLOCK_SIZE_D) + off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + # k offsets + off_k = (k_start_j + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + k_ptrs = k_ptr + pid_kh * stride_kh + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + + # init block score + bs = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + for i in range(num_offs): + k = tl.load(k_ptrs, mask=causal_mask, other=0) + w = tl.load(offs_ptr + i, mask=i < num_offs, other=0) + # compute qk + qk = tl.dot(q, k) * qk_scale + # compute score and apply weight + bs += w * tl.where(causal_mask, tl.exp2(qk - lse), 0) + + # increment pointers + off_k += 1 + k_ptrs = k_ptr + pid_kh * stride_kh + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + + # init mask and local mask + off_bq = off_q // block_size + off_bk = tl.arange(0, BLOCK_SIZE_K) + bs = tl.where( + ( + (off_bq[:, None] >= k_start_j + off_bk[None, :]) + & (off_bq[:, None] < k_start_j + off_bk[None, :] + local_blocks) + ) + | (off_bk[None, :] < init_blocks - k_start_j), + 
float("inf"), + bs, + ) + + # save output + bs_ptrs = ( + bs_ptr + + pid_kh.to(tl.int64) * stride_bsh + + q_start * stride_bsq + + k_start_j * stride_bsnb + + off_q[:, None] * stride_bsq + + off_bk[None, :] * stride_bsnb + ) + + tl.store( + bs_ptrs, + bs.to(bs_ptr.dtype.element_ty), + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start_j)[None, :], + ) + + +def _fused_attention_score_and_transform( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + init_blocks: int = 1, + local_blocks: int = 2, + align_baseline: bool = False, +) -> torch.Tensor: + + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + max_blocks = math.ceil(max_seqlen_q / block_size) + # init block score + block_scores = torch.zeros( + num_k_heads, + q_len, + max_blocks, + dtype=torch.float32 if align_baseline else torch.bfloat16, + device=q.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=q.device)[:, None] + + torch.arange(block_size // kernel_stride, device=q.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + for i in range(cu_seqlens_q.shape[0] - 1): + q_seq = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_seq = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + lse_seq = lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + block_scores_seq = block_scores[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + + _fused_attention_score_and_transform_per_seq( + q_seq, + k_seq, + lse_seq, + block_scores_seq, + kernel_size, + kernel_stride, + block_size, + offs, + num_offs, + cu_seqlens_q[i: i + 2] - cu_seqlens_q[i], + cu_seqlens_k[i: i + 2] - cu_seqlens_k[i], + 
cu_seqlens_q[i + 1] - cu_seqlens_q[i], + cu_seqlens_k[i + 1] - cu_seqlens_k[i], + sm_scale, + init_blocks, + local_blocks, + ) + block_scores[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = block_scores_seq + return block_scores + + +@torch.inference_mode() +def _fused_attention_score_and_transform_per_seq( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + block_score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + offs, + num_offs, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + + max_blocks = math.ceil(max_seqlen_q / block_size) + + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + + BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + # ensure qk is valid on triton + BLOCK_SIZE_K = max(BLOCK_SIZE_K, 16) + BLOCK_SIZE_Q = 128 + + # launch kernel + num_k_blocks = 1 + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K * num_k_blocks), + ) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + fused_score_kernel[grid]( + q, + k, + lse, + block_score, + offs, + kernel_size, + kernel_stride, + num_offs, + num_k_blocks, + cu_seqlens_q, + 
cu_seqlens_k, + num_k_heads, + head_dim, + sm_scale, + max_blocks, + pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) diff --git a/vortex_torch/kernels/nsa/__init__.py b/vortex_torch/kernels/nsa/__init__.py new file mode 100644 index 0000000..9af3029 --- /dev/null +++ b/vortex_torch/kernels/nsa/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .compressed_attention import compressed_attention +from .weighted_pool import (avgpool_compress, softmaxpool_compress, + weightedpool_compress) + +__all__ = [ + "compressed_attention", + "avgpool_compress", + "weightedpool_compress", + "softmaxpool_compress", +] diff --git a/vortex_torch/kernels/nsa/compressed_attention.py b/vortex_torch/kernels/nsa/compressed_attention.py new file mode 100644 index 0000000..9770a94 --- /dev/null +++ b/vortex_torch/kernels/nsa/compressed_attention.py @@ -0,0 +1,1317 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from typing import Any, Tuple, Union + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # size and stride at compresstion + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * 
stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # attention + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) 
+ # update ptrs + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + 
# sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 
0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(HEAD_DIM, q_len), + strides=(stride_qd, stride_qn), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(HEAD_DIM, q_len), + strides=(stride_dod, stride_don), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(1, q_len), + strides=(0, stride_dn), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(1, q_len), + strides=(0, stride_ln), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], 
float(0.0), float("-inf")) + qk += tl.dot(k, q) * qk_scale + # compute p, ds + # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + p = tl.exp2(qk - lse) + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + dp = tl.dot(v, do) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + # [BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM] + dk += tl.dot(ds, tl.trans(q)) + dv += tl.dot(p, tl.trans(do)) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q)) + do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q)) + lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q)) + d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - 
q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + 
q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + # load + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # increment pointers + k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) + v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K)) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _compressed_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert k_len == v_len and q_len > k_len + # 
gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + lse = torch.full( + (num_q_heads, q_len), + fill_value=-torch.inf, + dtype=torch.float32, + device=q.device, + ) + # launch kernel + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + forward_kernel[grid]( + q, + k, + v, + o, + lse, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _compressed_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale: float, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, 
BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + delta, + do, + dq, + kernel_size, + 
kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class CompressedAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _compressed_attention_fwd( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.kernel_size = kernel_size + ctx.kernel_stride = kernel_stride + return o, lse + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + kernel_size = ctx.kernel_size + kernel_stride = ctx.kernel_stride + + dq, dk, dv = _compressed_attention_bwd( + o, + do, + lse, + q, + k, + v, + 
kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +@triton.jit +def score_kernel( + q_ptr, + k_ptr, + lse_ptr, + s_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_sh, + stride_sq, + stride_sk, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + # init k pointer and load k + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + # init score + s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_kh * stride_qh, + shape=(q_len, HEAD_DIM), + 
strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k) * qk_scale + # compute score + s += tl.where(causal_mask, tl.exp2(qk - lse), 0) + # save output + s_ptrs = tl.make_block_ptr( + base=s_ptr + pid_kh * stride_sh + q_start * stride_sq, + shape=(q_len, k_len), + strides=(stride_sq, stride_sk), + offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + order=(1, 0), + ) + tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_attention_score( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + # gqa + assert num_q_heads % 
num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # init score + score = torch.zeros(num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device) + + # launch kernel + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + score_kernel[grid]( + q, + k, + lse, + score, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + score.stride(0), + score.stride(1), + score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) + return score + + +@triton.jit +def _transform_score_kernel( + s_ptr, # score, shape: [num_heads, q_len, k_len] + bs_ptr, # block wise score: [num_heads, q_len, num_k_block] + offs, + cu_seqlens_q, + # shape + num_heads, + num_offs, + max_k_len, + max_blocks, + pad_len, + # kernel & block size + block_size, + block_stride, # block_size // kernel_stride + init_blocks, + local_blocks, + # stride + stride_sh, + stride_sq, + stride_sk, + stride_bsh, + stride_bsq, + stride_bsk, + TOTAL_QUERY_LEN: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_O: tl.constexpr, +): + pid_bh = tl.program_id(0) + pid_b = pid_bh // num_heads + pid_h = pid_bh % num_heads + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = pid_k * BLOCK_SIZE_K + if pid_q * BLOCK_SIZE_Q >= q_len: + return + # load weight + off_o = tl.arange(0, BLOCK_SIZE_O) + w = tl.load(offs + off_o, mask=off_o < num_offs, other=0) + # load score + off_q = pid_q * 
BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + off_k = off_k[None, :] + off_o[:, None] + s_ptrs = ( + s_ptr + + q_start * stride_sq + + pid_h * stride_sh + + off_q[:, None, None] * stride_sq + + off_k[None, :, :] * stride_sk + ) + # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK] + s = tl.load( + s_ptrs, + mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len), + other=0, + ) + s = s * w[None, :, None] + s = tl.sum(s, axis=1) + # init mask and local mask + off_bq = off_q // block_size + off_bk = k_start + tl.arange(0, BLOCK_SIZE_K) + s = tl.where( + ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks)) + | (off_bk[None, :] < init_blocks - k_start), + float("inf"), + s, + ) + # store block wise score + bs_ptrs = ( + bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk + ) + tl.store( + bs_ptrs, + s, + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :], + ) + + +def transform_score( + score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + num_k_heads, total_query_len, max_key_len = score.shape + batch_size = cu_seqlens_q.shape[0] - 1 + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + block_score = torch.zeros( + num_k_heads, + total_query_len, + max_blocks, + dtype=torch.float32, + device=score.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=score.device)[:, None] + + torch.arange(block_size // kernel_stride, device=score.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + + BLOCK_SIZE_Q = 16 + 
BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + BLOCK_SIZE_O = triton.next_power_of_2(num_offs) + + def grid(meta): + grid = ( + num_k_heads * batch_size, + triton.cdiv(total_query_len, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K), + ) + return grid + + _transform_score_kernel[grid]( + score, + block_score, + offs, + cu_seqlens_q, + num_k_heads, + offs.shape[0], + max_key_len, + max_blocks, + pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + score.stride(0), + score.stride(1), + score.stride(2), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + TOTAL_QUERY_LEN=total_query_len, + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_O=BLOCK_SIZE_O, + num_warps=4, + num_stages=3, + ) + return block_score + + +def compressed_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float = None, + init_blocks: int = 1, + local_blocks: int = 2, + parallel_topk_compute: Union[str, bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + kernel_size (int): kernel size in compress_key_value + kernel_stride (int): stride of compress_key_value + block_size (int): key value block size for topk sparse attention. + topk (int): number of blocks for each query. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 
+ cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (int): max q len of the batch. + max_seqlen_k (int): max k len of the batch. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. + local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. + parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug. + We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention + """ + + if max_seqlen_q is None: + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + if max_seqlen_k is None: + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + attn_output, lse = CompressedAttention.apply( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + # do not select topk index + if topk <= 0: + warnings.warn("topk <= 0, returned topk_idx will be None") + return attn_output, None + + assert topk >= init_blocks + local_blocks + with torch.no_grad(): + num_k_heads, num_q_heads = k.shape[1], q.shape[1] + num_shared_q_heads = num_q_heads // num_k_heads + batch_size = cu_seqlens_q.shape[0] - 1 + q_idx = torch.cat( + [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)], + dim=0, + ) + q_idx = q_idx // block_size + + # whether to use parallel version + if parallel_topk_compute == "auto": + parallel_topk_compute = cu_seqlens_q[-1] <= 32768 + # parallel version + if parallel_topk_compute: + # recompute score + score = _get_attention_score( + q, + k, + lse, + kernel_size, + kernel_stride, 
+ cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + # non parallel version, avoid some current bugs when sequence length is too long + # FIXME: need to fix later + else: + topk_idx_list = [] + head_tile = 1 + assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}" + for h in range(num_k_heads // head_tile): + # recompute score + score = _get_attention_score( + q[:, h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + k[:, h * head_tile: (h + 1) * head_tile], + lse[h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + if score.dtype == torch.float32: + score = score.to(torch.bfloat16) + topk_idx = score.topk(topk, dim=-1, sorted=False).indices + topk_idx = topk_idx.sort(-1).values + + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + topk_idx_list.append(topk_idx) + topk_idx = torch.cat(topk_idx_list, dim=0) + + return attn_output, topk_idx diff --git a/vortex_torch/kernels/nsa/flash_attention.py b/vortex_torch/kernels/nsa/flash_attention.py new file mode 100644 index 0000000..c556a4c --- /dev/null +++ 
b/vortex_torch/kernels/nsa/flash_attention.py @@ -0,0 +1,886 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - 
k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # full attention or causal attention + lo = 0 + if causal: + hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q) + else: + hi = k_len + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + if causal: + qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf")) + else: + qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.math.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o 
+ v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) + # update ptrs + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # 
Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + if gqa_interleave: + pid_kh = pid_h % NUM_SHARE_Q_HEADS + pid_sh = pid_h // NUM_SHARE_Q_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + 
strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + q_lo = pid_k * BLOCK_SIZE_K + else: + q_lo = 0 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), 
padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where((off_q + i)[:, None] >= off_k[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0)) + do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0)) + lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0)) + d_ptrs = tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // 
    NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    # cu_seqlens_* are cumulative sequence lengths (flash-attn varlen convention):
    # per-batch start offset and length into the packed (un-padded) token dim.
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # this program's Q tile lies entirely past the end of this sequence
    if BLOCK_SIZE_Q * pid_q >= q_len:
        return
    # init pointers
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_dqn, stride_dqd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # K/V use the kv head index pid_kh (GQA: several q heads share one kv head)
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_don, stride_dod),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # delta (D = rowsum(o * do)) and lse are per-(head, token) scalars, loaded
    # as (BLOCK_SIZE_Q, 1) columns so they broadcast against (Q, K) tiles
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K)
    # load q, do, lse, delta, and keep in SRAM
    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dq
    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
    # causal: only K blocks up to (and including) this Q tile contribute
    if causal:
        k_hi = (pid_q + 1) * BLOCK_SIZE_Q
    else:
        k_hi = k_len
    for j in range(0, k_hi, BLOCK_SIZE_K):
        # load
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk; causal positions outside the triangle start at -inf
        if causal:
            qk = tl.where(off_q[:, None] >= (off_k + j)[None, :], float(0.0), float("-inf"))
        else:
            qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        # NOTE(review): qk_scale is defined in the (off-screen) kernel preamble;
        # since exp2 is used below it presumably equals sm_scale * log2(e) — confirm.
        qk += tl.dot(q, k.T) * qk_scale
        # compute p, ds: p = exp2(qk - lse) recovers softmax probabilities,
        # ds = sm_scale * p * (dp - D) is the standard flash-attn backward term
        p = tl.math.exp2(qk - lse)
        dp = tl.dot(do, v.T)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        ds = ds.to(q.dtype)
        # update dq
        dq += tl.dot(ds, k)
        # increment pointers
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # save dq
    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))


def _flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: torch.Tensor,
    max_seqlen_k: torch.Tensor,
    causal: bool,
    sm_scale: float,
    gqa_interleave: bool = False,
):
    """Launch the varlen flash-attention forward kernel.

    Returns (o, lse): attention output shaped like q, and the per-(head, token)
    log-sum-exp tensor [num_q_heads, total_q_len] kept for the backward pass.
    """
    # dtype check
    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # shape: packed (un-padded) layouts [total_len, num_heads, head_dim]
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    # assert q_len == k_len and k_len == v_len
    # gqa
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # output tensor
    o = torch.empty_like(q)
    lse = torch.empty(num_q_heads, q_len, dtype=torch.float32, device=q.device)
    # launch kernel: one program per (batch, q head, Q tile)
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 64
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    forward_kernel[grid](
        q,
        k,
        v,
        o,
        lse,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        causal,
        gqa_interleave,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        lse.stride(0),
        lse.stride(1),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse


def _flash_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: torch.Tensor,
    max_seqlen_k: torch.Tensor,
    causal: bool,
    sm_scale: float,
    gqa_interleave: bool = False,
):
    """Launch the three backward kernels (delta, dk/dv, dq); returns (dq, dk, dv)."""
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    o_len, num_o_heads, head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    # compute D = rowsum(o * do), needed by both dq and dk/dv kernels
    delta = torch.empty([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
    BLOCK_SIZE_O = 256
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
    backward_sum_o_do[grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        o.stride(0),
        o.stride(1),
        o.stride(2),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        delta.stride(0),
        delta.stride(1),
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # compute dk dv: one partial-gradient slice per shared q head, summed below
    # (avoids atomics across the GQA group)
    dk = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
    dv = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype)
    batch_size = cu_seqlens_q.shape[0] - 1
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
    )
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_K = 64
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
    backward_dkdv[grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        causal,
        gqa_interleave,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),
        dk.stride(3),
        dv.stride(0),
        dv.stride(1),
        dv.stride(2),
        dv.stride(3),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # reduce the per-shared-head partial gradients
    dk = dk.sum(0)
    dv = dv.sum(0)
    # compute dq
    dq = torch.empty_like(q)
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 64
    # BLOCK_SIZE_D intentionally reused from the dk/dv launch above
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    backward_dq[grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dq,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        causal,
        gqa_interleave,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dq.stride(0),
        dq.stride(1),
        dq.stride(2),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv


class FlashAttention(torch.autograd.Function):
    """Autograd wrapper tying the varlen forward/backward launchers together."""

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: torch.Tensor,
        max_seqlen_k: torch.Tensor,
        causal=True,
        sm_scale=None,
        gqa_interleave=False,
    ):
        # softmax scale defaults to 1/sqrt(head_dim)
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])
        o, lse = _flash_attention_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            causal,
            sm_scale,
            gqa_interleave,
        )
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.causal = causal
        ctx.gqa_interleave = gqa_interleave
        return o

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_k = ctx.max_seqlen_k
        sm_scale = ctx.sm_scale
        causal = ctx.causal
        gqa_interleave = ctx.gqa_interleave
        dq, dk, dv = _flash_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            causal,
            sm_scale,
            gqa_interleave,
        )
        # one gradient per forward() tensor arg; None for the 7 non-tensor args
        return dq, dk, dv, None, None, None, None, None, None, None


def flash_attention_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: torch.Tensor,
    max_seqlen_k: torch.Tensor,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    gqa_interleave: bool = False,
) -> torch.Tensor:
    """Flash attention with variable length based on triton.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (torch.Tensor): max q len of the batch.
        max_seqlen_k (torch.Tensor): max k len of the batch.
        causal (bool, optional): Causal mask. Defaults to False.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        gqa_interleave (bool, optional): GQA pattern. Defaults to False, use Llama style GQA.

    Returns:
        torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim]
    """
    return FlashAttention.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        causal,
        sm_scale,
        gqa_interleave,
    )

# diff --git a/vortex_torch/kernels/nsa/utils.py b/vortex_torch/kernels/nsa/utils.py
# new file mode 100644
# index 0000000..1f158a1
# --- /dev/null
# +++ b/vortex_torch/kernels/nsa/utils.py
# @@ -0,0 +1,50 @@
import torch


def is_hopper_gpu():
    """Return True iff CUDA device 0 reports compute capability 9.x (Hopper)."""
    if torch.cuda.is_available():
        device_capability = torch.cuda.get_device_capability(0)
        major, minor = device_capability
        return major == 9
    return False


def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    # Determine if head_dim and block_size exceed 64
    head_large = head_dim > 64
    block_large = block_size > 64

    if is_hopper_gpu:
        # Hopper GPU recommendations
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 4
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    else:
        # Ampere GPU recommendations
        # NOTE(review): the first two Ampere branches are identical (8, 3) —
        # either collapse them or confirm whether the `elif` case was meant
        # to use 4 warps like the Hopper path.
        if head_large and block_large:
            num_warps = 8
            num_stages = 3
        elif head_large or block_large:
            num_warps = 8
            num_stages = 3
        else:
            num_warps = 2
            num_stages = 2
    return num_warps, num_stages

# diff --git a/vortex_torch/kernels/nsa/weighted_pool.py b/vortex_torch/kernels/nsa/weighted_pool.py
# new file mode 100644
# index 0000000..abfe9d3
# --- /dev/null
# +++ b/vortex_torch/kernels/nsa/weighted_pool.py
# @@ -0,0 +1,341 @@
# Copyright 2025 Xunhao Lai.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
import triton
import triton.language as tl
from einops import einsum


@triton.jit
def sliding_pool_fwd_kernel(
    x_ptr,
    y_ptr,
    w_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_yn,
    stride_yh,
    stride_yd,
    stride_wh,
    stride_wk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Pool one window: y[b, pid_k, h, :] = weighted (or mean) sum over a
    kernel_size-long slice of x starting at pid_k * kernel_stride.

    One program per (batch, head, output position). When w_ptr is None the
    kernel falls back to average pooling (triton specializes on None args).
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad (packed varlen layout)
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    if pid_k >= y_len:
        return
    if w_ptr is not None:
        # load per-head window weights, shape (kernel_size, 1)
        w_ptrs = tl.make_block_ptr(
            base=w_ptr + pid_h * stride_wh,
            shape=(kernel_size, 1),
            strides=(stride_wk, 0),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, 1),
            order=(0, 1),
        )
        w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load the (kernel_size, head_dim) input window
    x_ptrs = tl.make_block_ptr(
        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
        shape=(x_len, head_dim),
        strides=(stride_xn, stride_xd),
        offsets=(pid_k * kernel_stride, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # compute y: weighted sum, or plain mean when no weights given
    if w_ptr is not None:
        y = tl.sum(x * w, axis=0)
    else:
        y = tl.sum(x, axis=0) / kernel_size
    off_d = tl.arange(0, BLOCK_SIZE_D)
    tl.store(
        y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd,
        y.to(y_ptr.dtype.element_ty),
        mask=off_d < head_dim,
    )


@triton.jit
def sliding_pool_dxdw_kernel(
    x_ptr,
    dx_ptr,
    dy_ptr,
    w_ptr,
    dw_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_dxn,
    stride_dxh,
    stride_dxd,
    stride_dyn,
    stride_dyh,
    stride_dyd,
    stride_wh,
    stride_wk,
    stride_dwh,
    stride_dwn,
    stride_dwk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward of sliding_pool_fwd_kernel: scatter dx into overlapping windows
    (atomic_add, since windows overlap when kernel_stride < kernel_size) and
    store per-output-position dw partials (reduced on the host side).
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    if pid_k >= y_len:
        return
    # offsets
    off_d = tl.arange(0, BLOCK_SIZE_D)
    off_k = tl.arange(0, BLOCK_SIZE_K)
    if w_ptr is not None:
        # load per-head window weights
        w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk
        w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0)
    # load the input window transposed: (head_dim, kernel_size)
    x_ptrs = tl.make_block_ptr(
        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
        shape=(head_dim, x_len),
        strides=(stride_xd, stride_xn),
        offsets=(0, pid_k * kernel_stride),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load dy for this output position
    dy_ptrs = dy_ptr + pid_h * stride_dyh + (y_start + pid_k) * stride_dyn + off_d * stride_dyd
    dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0)
    if w_ptr is not None:
        # dx = outer(w, dy): [1, D] x [K, 1] -> [K, D]
        dx = dy[None, :] * w[:, None]
        # dw = sum over D of dy * x: [D, 1] x [D, K] -> [K]
        dw = tl.sum(dy[:, None] * x, axis=0)
        # store this window's dw partial (summed over positions later)
        dw_ptrs = dw_ptr + pid_h * stride_dwh + (y_start + pid_k) * stride_dwn + off_k * stride_dwk
        tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size)
    else:
        # average pooling: gradient spreads evenly over the window
        dx = dy[None, :] / kernel_size
    # scatter dx into x's gradient; atomic because windows may overlap
    dx_ptrs = (
        dx_ptr
        + pid_h * stride_dxh
        + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn
        + off_d[None, :] * stride_dxd
    )
    tl.atomic_add(
        dx_ptrs,
        dx.to(dx_ptr.dtype.element_ty),
        mask=(off_k < x_len - pid_k * kernel_stride)[:, None] & (off_d < head_dim)[None, :],
    )


class SlidingWindowWeightedPool(torch.autograd.Function):
    """Sliding-window (optionally weighted) pooling over packed varlen sequences.

    Compresses each sequence of x into floor((len - kernel_size) / kernel_stride) + 1
    pooled tokens; sequences shorter than kernel_size produce zero output tokens.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,  # [total_len, num_heads, head_dim]
        w: torch.Tensor,  # [num_heads, kernel_size], or None for average pooling
        cu_seqlens: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
    ):
        # dtype check
        assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
        if w is not None:
            assert x.dtype == w.dtype
        assert cu_seqlens.dtype == torch.int32
        # shape check
        total_len, num_heads, head_dim = x.shape
        batch_size = cu_seqlens.shape[0] - 1
        if w is not None:
            assert w.shape[0] == num_heads
            assert w.shape[1] == kernel_size
        assert kernel_size % kernel_stride == 0
        assert kernel_size in {16, 32, 64, 128}
        # compute seqlens after compression
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
        # corner case, if sequence_length < kernel_size, no compression for this sequence
        y_seqlens[seqlens < kernel_size] = 0
        # BUGFIX: was device="cuda", which breaks on non-default CUDA devices
        # (e.g. cuda:1) since torch.cat requires all inputs on the same device.
        y_cu_seqlens = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
                torch.cumsum(y_seqlens, dim=0),
            ],
            dim=0,
        ).to(torch.int32)
        # output buffer
        y = torch.zeros(y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device)
        # launch kernel: one program per (batch, head, output position)
        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)
        grid = (batch_size, num_heads, y_seqlens.max().item())
        sliding_pool_fwd_kernel[grid](
            x,
            y,
            w,
            cu_seqlens,
            y_cu_seqlens,
            head_dim,
            kernel_size,
            kernel_stride,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            y.stride(0),
            y.stride(1),
            y.stride(2),
            w.stride(0) if w is not None else None,
            w.stride(1) if w is not None else None,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )
        ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens)
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        ctx.head_dim = head_dim
        return y, y_cu_seqlens

    @staticmethod
    def backward(ctx, dy, _):
        x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        head_dim = ctx.head_dim
        batch_size = cu_seqlens.shape[0] - 1
        num_heads = x.shape[1]
        # compute dx in fp32: the kernel accumulates via atomic_add
        dx = torch.zeros_like(x, dtype=torch.float32)
        if w is not None:
            # per-output-position dw partials, reduced below
            dw = torch.zeros(
                num_heads,
                y_cu_seqlens[-1],
                kernel_size,
                dtype=torch.float32,
                device=w.device,
            )
        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)
        # NOTE(review): if every sequence is shorter than kernel_size the grid's
        # third dim is 0 — confirm callers never hit that (same as forward).
        grid = (batch_size, num_heads, y_seqlens.max().item())
        sliding_pool_dxdw_kernel[grid](
            x,
            dx,
            dy,
            w,
            dw if w is not None else None,
            cu_seqlens,
            y_cu_seqlens,
            head_dim,
            kernel_size,
            kernel_stride,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dy.stride(0),
            dy.stride(1),
            dy.stride(2),
            w.stride(0) if w is not None else None,
            w.stride(1) if w is not None else None,
            dw.stride(0) if w is not None else None,
            dw.stride(1) if w is not None else None,
            dw.stride(2) if w is not None else None,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )
        dx = dx.to(x.dtype)
        if w is None:
            dw = None
        else:
            # reduce partials over output positions
            dw = dw.sum(1).to(w.dtype)
        return dx, dw, None, None, None


def weightedpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # [num_heads, kernel_size]
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Weighted sliding-window compression; pe (if given) is a per-head
    positional embedding [num_heads, kernel_size, head_dim] folded through w
    into a per-head bias.
    """
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = einsum(pe, w, "h k d, h k -> h d")
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens


def avgpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # don't need weight
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Average-pooling compression; w must be None, pe bias is the window mean."""
    assert w is None, "don't need additional weight for avgpool"
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = torch.mean(pe, dim=1)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens


def softmaxpool_compress(
    x: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Like weightedpool_compress but normalizes w with softmax over the window."""
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w.softmax(-1), cu_seqlens, kernel_size, kernel_stride)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = torch.mean(pe, dim=1)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens