diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
deleted file mode 100644
index bbed80ebe847..000000000000
--- a/.buildkite/generate_index.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import argparse
-import os
-
-template = """<!DOCTYPE html>
-<html>
-    <body>
-        <h1>Links for vLLM</h1>
-        <a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
-        <a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
-    </body>
-</html>
-"""
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--wheel", help="The wheel path.", required=True)
-args = parser.parse_args()
-
-filename = os.path.basename(args.wheel)
-
-with open("index.html", "w") as f:
- print(f"Generated index.html for {args.wheel}")
- # sync the abi tag with .buildkite/scripts/upload-wheels.sh
- if "x86_64" in filename:
- x86_wheel = filename
- arm_wheel = filename.replace("x86_64", "aarch64").replace(
- "manylinux1", "manylinux2014"
- )
- elif "aarch64" in filename:
- x86_wheel = filename.replace("aarch64", "x86_64").replace(
- "manylinux2014", "manylinux1"
- )
- arm_wheel = filename
- else:
- raise ValueError(f"Unsupported wheel: {filename}")
- # cloudfront requires escaping the '+' character
- f.write(
- template.format(
- x86_wheel=x86_wheel,
- x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
- arm_wheel=arm_wheel,
- arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
- )
- )
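For context: the comment in the deleted script above notes that CloudFront requires escaping the '+' character in wheel filenames, something the replacement index generator also has to handle. A standalone illustration (not part of the patch):

```python
# Standalone illustration (not part of this patch) of the '+' escaping that
# wheel indices need: the old script used str.replace, while the new generator
# can rely on urllib.parse.quote, which percent-encodes '+' but keeps '/'.
from urllib.parse import quote

wheel = "vllm-0.10.2rc2+cu129-cp38-abi3-manylinux1_x86_64.whl"

print(wheel.replace("+", "%2B"))  # old approach: manual replacement
print(quote("../" + wheel))       # '+' becomes %2B, '/' and '.' stay as-is
```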
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 38c400ba1faf..fbfc923998f8 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -8,7 +8,7 @@ steps:
commands:
# #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
@@ -30,19 +30,6 @@ steps:
DOCKER_BUILDKIT: "1"
# x86 + CUDA builds
- - label: "Build wheel - CUDA 12.8"
- depends_on: ~
- id: build-wheel-cuda-12-8
- agents:
- queue: cpu_queue_postmerge
- commands:
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- - "mkdir artifacts"
- - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- - "bash .buildkite/scripts/upload-wheels.sh"
- env:
- DOCKER_BUILDKIT: "1"
-
- label: "Build wheel - CUDA 12.9"
depends_on: ~
id: build-wheel-cuda-12-9
@@ -109,7 +96,6 @@ steps:
- label: "Annotate release workflow"
depends_on:
- create-multi-arch-manifest
- - build-wheel-cuda-12-8
id: annotate-release-workflow
agents:
queue: cpu_queue_postmerge
diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh
index 56bb5cedaa0a..df805e085080 100755
--- a/.buildkite/scripts/annotate-release.sh
+++ b/.buildkite/scripts/annotate-release.sh
@@ -23,8 +23,8 @@ To download the wheel (by version):
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl .
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl .
-aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl .
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu129/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl .
+aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu130/vllm-${RELEASE_VERSION}+cu130-cp38-abi3-manylinux1_x86_64.whl .
\`\`\`
To download and upload the image:
@@ -45,9 +45,10 @@ docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
docker push vllm/vllm-openai:latest-aarch64
docker push vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
-docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64 --amend
-docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 --amend
+docker manifest rm vllm/vllm-openai:latest
+docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64
+docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
docker manifest push vllm/vllm-openai:latest
docker manifest push vllm/vllm-openai:v${RELEASE_VERSION}
\`\`\`
-EOF
\ No newline at end of file
+EOF
diff --git a/.buildkite/scripts/generate-nightly-index.py b/.buildkite/scripts/generate-nightly-index.py
new file mode 100644
index 000000000000..a61f08107647
--- /dev/null
+++ b/.buildkite/scripts/generate-nightly-index.py
@@ -0,0 +1,368 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# do not complain about line length (for docstring)
+# ruff: noqa: E501
+
+import argparse
+import json
+import re
+import sys
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any
+from urllib.parse import quote
+
+# Path.relative_to(..., walk_up=True) used below requires Python 3.12+
+if sys.version_info < (3, 12):
+    raise RuntimeError("This script requires Python 3.12 or higher.")
+
+INDEX_HTML_TEMPLATE = """
+<!DOCTYPE html>
+<html>
+<body>
+{items}
+</body>
+</html>
+"""
+
+
+@dataclass
+class WheelFileInfo:
+ package_name: str
+ version: str
+ build_tag: str | None
+ python_tag: str
+ abi_tag: str
+ platform_tag: str
+ variant: str | None
+ filename: str
+
+
+def parse_from_filename(file: str) -> WheelFileInfo:
+ """
+ Parse wheel file name to extract metadata.
+
+ The format of wheel names:
+ {package_name}-{version}(-{build_tag})?-{python_tag}-{abi_tag}-{platform_tag}.whl
+    The version may also carry a variant suffix such as '+cu129', '.cpu', or '.rocm' (or none at all).
+ Example:
+ vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl
+ vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl
+ vllm-0.11.1rc8.dev14+gaa384b3c0-cp38-abi3-manylinux2014_aarch64.whl
+ vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl
+ """
+ wheel_file_re = re.compile(
+        r"^(?P<package_name>.+)-(?P<version>[^-]+?)(-(?P<build_tag>[^-]+))?-(?P<python_tag>[^-]+)-(?P<abi_tag>[^-]+)-(?P<platform_tag>[^-]+)\.whl$"
+ )
+ match = wheel_file_re.match(file)
+ if not match:
+ raise ValueError(f"Invalid wheel file name: {file}")
+
+ package_name = match.group("package_name")
+ version = match.group("version")
+ build_tag = match.group("build_tag")
+ python_tag = match.group("python_tag")
+ abi_tag = match.group("abi_tag")
+ platform_tag = match.group("platform_tag")
+
+ # extract variant from version
+ variant = None
+ if "dev" in version:
+ ver_after_dev = version.split("dev")[-1]
+ if "." in ver_after_dev:
+ variant = ver_after_dev.split(".")[-1]
+ version = version.removesuffix("." + variant)
+ else:
+ if "+" in version:
+ version, variant = version.split("+")
+
+ return WheelFileInfo(
+ package_name=package_name,
+ version=version,
+ build_tag=build_tag,
+ python_tag=python_tag,
+ abi_tag=abi_tag,
+ platform_tag=platform_tag,
+ variant=variant,
+ filename=file,
+ )
+
+
+def generate_project_list(subdir_names: list[str]) -> str:
+ """
+ Generate project list HTML content linking to each project & variant sub-directory.
+ """
+ href_tags = []
+ for name in sorted(subdir_names):
+ name = name.strip("/").strip(".")
+        href_tags.append(f'    <a href="{name}/">{name}/</a><br/>')
+ return INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags))
+
+
+def generate_package_index_and_metadata(
+ wheel_files: list[WheelFileInfo], wheel_base_dir: Path, index_base_dir: Path
+) -> tuple[str, str]:
+ """
+ Generate package index HTML content for a specific package, linking to actual wheel files.
+ """
+ href_tags = []
+ metadata = []
+ for file in sorted(wheel_files, key=lambda x: x.filename):
+ relative_path = (
+ wheel_base_dir.relative_to(index_base_dir, walk_up=True) / file.filename
+ )
+        href_tags.append(
+            f'    <a href="{quote(relative_path.as_posix())}">{file.filename}</a><br/>'
+        )
+ file_meta = asdict(file)
+ file_meta["path"] = relative_path.as_posix()
+ metadata.append(file_meta)
+ index_str = INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags))
+ metadata_str = json.dumps(metadata, indent=2)
+ return index_str, metadata_str
+
+
+def generate_index_and_metadata(
+ whl_files: list[str],
+ wheel_base_dir: Path,
+ index_base_dir: Path,
+ default_variant: str | None = None,
+ alias_to_default: str | None = None,
+):
+ """
+ Generate index for all wheel files.
+
+ Args:
+ whl_files (list[str]): List of wheel files (must be directly under `wheel_base_dir`).
+ wheel_base_dir (Path): Base directory for wheel files.
+ index_base_dir (Path): Base directory to store index files.
+ default_variant (str | None): The default variant name, if any.
+ alias_to_default (str | None): Alias variant name for the default variant, if any.
+
+ First, parse all wheel files to extract metadata.
+ We need to collect all wheel files for each variant, and generate an index for it (in a sub-directory).
+ The index for the default variant (if any) is generated in the root index directory.
+
+ If `default_variant` is provided, all wheels must have variant suffixes, and the default variant index
+ is purely a copy of the corresponding variant index, with only the links adjusted.
+ Otherwise, all wheels without variant suffixes are treated as the default variant.
+
+    If `alias_to_default` is provided, an additional alias sub-directory is created; it has the same content
+    as the default variant index, with the links adjusted accordingly.
+
+ Index directory structure:
+ index_base_dir/ (hosted at wheels.vllm.ai/{nightly,$commit,$version}/)
+ index.html # project list, linking to "vllm/" and other packages, and all variant sub-directories
+ vllm/
+ index.html # package index, pointing to actual files in wheel_base_dir (relative path)
+ metadata.json # machine-readable metadata for all wheels in this package
+ cpu/ # cpu variant sub-directory
+ index.html
+ vllm/
+ index.html
+ metadata.json
+ cu129/ # cu129 is actually the alias to default variant
+ index.html
+ vllm/
+ index.html
+ metadata.json
+ cu130/ # cu130 variant sub-directory
+ index.html
+ vllm/
+ index.html
+ metadata.json
+ ...
+
+ metadata.json stores a dump of all wheel files' metadata in a machine-readable format:
+ [
+ {
+ "package_name": "vllm",
+ "version": "0.10.2rc2",
+ "build_tag": null,
+ "python_tag": "cp38",
+ "abi_tag": "abi3",
+ "platform_tag": "manylinux2014_aarch64",
+ "variant": "cu129",
+ "filename": "vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl",
+ "path": "../vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl" # to be concatenated with the directory URL
+ },
+ ...
+ ]
+ """
+
+ parsed_files = [parse_from_filename(f) for f in whl_files]
+
+ if not parsed_files:
+ print("No wheel files found, skipping index generation.")
+ return
+
+ # Group by variant
+ variant_to_files: dict[str, list[WheelFileInfo]] = {}
+ for file in parsed_files:
+ variant = file.variant or "default"
+ if variant not in variant_to_files:
+ variant_to_files[variant] = []
+ variant_to_files[variant].append(file)
+
+ print(f"Found variants: {list(variant_to_files.keys())}")
+
+ # sanity check for default variant
+ if default_variant:
+ if "default" in variant_to_files:
+ raise ValueError(
+ "All wheel files must have variant suffixes when `default_variant` is specified."
+ )
+ if default_variant not in variant_to_files:
+ raise ValueError(
+ f"Default variant '{default_variant}' not found among wheel files."
+ )
+
+ if alias_to_default:
+ if "default" not in variant_to_files:
+ # e.g. only some wheels are uploaded to S3 currently
+ print(
+ "[WARN] Alias to default variant specified, but no default variant found."
+ )
+ elif alias_to_default in variant_to_files:
+ raise ValueError(
+ f"Alias variant name '{alias_to_default}' already exists among wheel files."
+ )
+ else:
+ variant_to_files[alias_to_default] = variant_to_files["default"].copy()
+ print(f"Alias variant '{alias_to_default}' created for default variant.")
+
+ # Generate index for each variant
+ subdir_names = set()
+ for variant, files in variant_to_files.items():
+ if variant == "default":
+ variant_dir = index_base_dir
+ else:
+ variant_dir = index_base_dir / variant
+ subdir_names.add(variant)
+
+ variant_dir.mkdir(parents=True, exist_ok=True)
+
+ # gather all package names in this variant
+ packages = set(f.package_name for f in files)
+ if variant == "default":
+            # these packages should also appear in the "project list",
+            # which is generated after all variants are processed
+ subdir_names = subdir_names.union(packages)
+ else:
+ # generate project list for this variant directly
+ project_list_str = generate_project_list(sorted(packages))
+ with open(variant_dir / "index.html", "w") as f:
+ f.write(project_list_str)
+
+ for package in packages:
+ # filter files belonging to this package only
+ package_files = [f for f in files if f.package_name == package]
+ package_dir = variant_dir / package
+ package_dir.mkdir(parents=True, exist_ok=True)
+ index_str, metadata_str = generate_package_index_and_metadata(
+ package_files, wheel_base_dir, package_dir
+ )
+ with open(package_dir / "index.html", "w") as f:
+ f.write(index_str)
+ with open(package_dir / "metadata.json", "w") as f:
+ f.write(metadata_str)
+
+ # Generate top-level project list index
+ project_list_str = generate_project_list(sorted(subdir_names))
+ with open(index_base_dir / "index.html", "w") as f:
+ f.write(project_list_str)
+
+
+if __name__ == "__main__":
+ """
+ Arguments:
+ --version : version string for the current build (e.g., commit hash)
+ --current-objects : path to JSON file containing current S3 objects listing in this version directory
+ --output-dir : directory to store generated index files
+ --alias-to-default : (optional) alias variant name for the default variant
+ """
+
+ parser = argparse.ArgumentParser(
+ description="Process nightly build wheel files to generate indices."
+ )
+ parser.add_argument(
+ "--version",
+ type=str,
+ required=True,
+ help="Version string for the current build (e.g., commit hash)",
+ )
+ parser.add_argument(
+ "--current-objects",
+ type=str,
+ required=True,
+ help="Path to JSON file containing current S3 objects listing in this version directory",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ required=True,
+ help="Directory to store generated index files",
+ )
+ parser.add_argument(
+ "--alias-to-default",
+ type=str,
+ default=None,
+ help="Alias variant name for the default variant",
+ )
+
+ args = parser.parse_args()
+
+ version = args.version
+ if "/" in version or "\\" in version:
+ raise ValueError("Version string must not contain slashes.")
+ current_objects_path = Path(args.current_objects)
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Read current objects JSON
+ with open(current_objects_path) as f:
+ current_objects: dict[str, list[dict[str, Any]]] = json.load(f)
+
+ # current_objects looks like from list_objects_v2 S3 API:
+    # current_objects looks like the output of the S3 list_objects_v2 API:
+ "Contents": [
+ {
+ "Key": "e2f56c309d2a28899c68975a7e104502d56deb8f/vllm-0.11.2.dev363+ge2f56c309-cp38-abi3-manylinux1_x86_64.whl",
+ "LastModified": "2025-11-28T14:00:32+00:00",
+ "ETag": "\"37a38339c7cdb61ca737021b968075df-52\"",
+ "ChecksumAlgorithm": [
+ "CRC64NVME"
+ ],
+ "ChecksumType": "FULL_OBJECT",
+ "Size": 435649349,
+ "StorageClass": "STANDARD"
+ },
+ ...
+ ]
+ """
+
+ # Extract wheel file keys
+ wheel_files = []
+ for item in current_objects.get("Contents", []):
+ key: str = item["Key"]
+ if key.endswith(".whl"):
+ wheel_files.append(key.split("/")[-1]) # only the filename is used
+
+ print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
+
+    # Generate index and metadata, assuming wheels and indices are stored as:
+    #   wheels:  s3://vllm-wheels/{version}/
+    #   indices: s3://vllm-wheels/<commit, nightly, or release version>/ (a sibling of the wheel directory)
+ wheel_base_dir = Path(output_dir).parent / version
+ index_base_dir = Path(output_dir)
+
+ generate_index_and_metadata(
+ whl_files=wheel_files,
+ wheel_base_dir=wheel_base_dir,
+ index_base_dir=index_base_dir,
+ default_variant=None,
+ alias_to_default=args.alias_to_default,
+ )
+ print(f"Successfully generated index and metadata in {output_dir}")
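For reference, here is a small standalone sketch (not part of the patch) that mirrors the filename-parsing rules documented in parse_from_filename() above, applied to the example names from its docstring:

```python
# Standalone sketch (not part of this patch) mirroring the wheel-name parsing
# rules of parse_from_filename(): split the wheel filename with the same regex,
# then peel the build variant off the version the same way the script does.
import re

WHEEL_RE = re.compile(
    r"^(?P<package_name>.+)-(?P<version>[^-]+?)(-(?P<build_tag>[^-]+))?"
    r"-(?P<python_tag>[^-]+)-(?P<abi_tag>[^-]+)-(?P<platform_tag>[^-]+)\.whl$"
)


def split_variant(version: str) -> tuple[str, str | None]:
    """Return (version, variant), e.g. '0.10.2rc2+cu129' -> ('0.10.2rc2', 'cu129')."""
    if "dev" in version:
        after_dev = version.split("dev")[-1]
        if "." in after_dev:
            variant = after_dev.split(".")[-1]
            return version.removesuffix("." + variant), variant
        return version, None
    if "+" in version:
        base, variant = version.split("+")
        return base, variant
    return version, None


for name in [
    "vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl",
    "vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl",
    "vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl",
]:
    match = WHEEL_RE.match(name)
    assert match is not None
    print(name, "->", split_variant(match.group("version")))
```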
diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh
index 0e5b21ddf25b..864eb470bb0a 100755
--- a/.buildkite/scripts/hardware_ci/run-amd-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -59,7 +59,7 @@ while true; do
fi
done
-echo "--- Pulling container"
+echo "--- Pulling container"
image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
docker pull "${image_name}"
@@ -177,13 +177,13 @@ if [[ -z "$render_gid" ]]; then
exit 1
fi
-# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
+# Check if the command contains a shard flag; if so, run all shards in parallel because the host has 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
- # assign job count as the number of shards used
- commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
+ # assign job count as the number of shards used
+ commands=$(echo "$commands" | sed -E "s/--num-shards[[:blank:]]*=[[:blank:]]*[0-9]*/--num-shards=${PARALLEL_JOB_COUNT} /g" | sed 's/ \\ / /g')
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
# assign shard-id for each shard
- commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
+ commands_gpu=$(echo "$commands" | sed -E "s/--shard-id[[:blank:]]*=[[:blank:]]*[0-9]*/--shard-id=${GPU} /g" | sed 's/ \\ / /g')
echo "Shard ${GPU} commands:$commands_gpu"
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
docker run \
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
new file mode 100755
index 000000000000..b5f6b2494792
--- /dev/null
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+
+# This script builds the CPU docker image and runs the offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# allow to bind to different cores
+CORE_RANGE=${CORE_RANGE:-0-16}
+OMP_CORE_RANGE=${OMP_CORE_RANGE:-0-16}
+
+export CMAKE_BUILD_PARALLEL_LEVEL=16
+
+# Setup cleanup
+remove_docker_container() {
+ set -e;
+ docker rm -f cpu-test || true;
+}
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Try building the docker image
+docker build --tag cpu-test --target vllm-test -f docker/Dockerfile.cpu .
+
+# Run the image
+docker run -itd --cpuset-cpus="$CORE_RANGE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test cpu-test
+
+function cpu_tests() {
+ set -e
+
+ docker exec cpu-test bash -c "
+ set -e
+ pip list"
+
+ # offline inference
+ docker exec cpu-test bash -c "
+ set -e
+ python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
+
+ # Run kernel tests
+ docker exec cpu-test bash -c "
+ set -e
+ pytest -x -v -s tests/kernels/test_onednn.py
+ pytest -x -v -s tests/kernels/attention/test_cpu_attn.py"
+
+ # basic online serving
+ docker exec cpu-test bash -c '
+ set -e
+ VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS vllm serve Qwen/Qwen3-0.6B --max-model-len 2048 &
+ server_pid=$!
+ timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
+ vllm bench serve \
+ --backend vllm \
+ --dataset-name random \
+ --model Qwen/Qwen3-0.6B \
+ --num-prompts 20 \
+ --endpoint /v1/completions
+ kill -s SIGTERM $server_pid &'
+}
+
+# All CPU tests are expected to finish in less than 40 minutes.
+export -f cpu_tests
+timeout 2h bash -c cpu_tests
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
index 39ea18017308..3728f73fa2a3 100755
--- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
@@ -25,20 +25,22 @@ function cpu_tests() {
# offline inference
podman exec -it "$container_id" bash -c "
+ export TORCH_COMPILE_DISABLE=1
set -xve
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" >> $HOME/test_basic.log
# Run basic model test
podman exec -it "$container_id" bash -c "
+ export TORCH_COMPILE_DISABLE=1
set -evx
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
- pip install sentence-transformers datamodel_code_generator
+ pip install sentence-transformers datamodel_code_generator tblib
# Note: disable Bart until supports V1
# pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
- pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
- pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
- pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-openai-community/gpt2]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-facebook/opt-125m]
+ pytest -v -s tests/models/language/generation/test_common.py::test_models[False-False-5-32-google/gemma-1.1-2b-it]
pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
# TODO: Below test case tests/models/language/pooling/test_embedding.py::test_models[True-ssmits/Qwen2-7B-Instruct-embed-base] fails on ppc64le. Disabling it for time being.
# pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" >> $HOME/test_rest.log
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index 7479c43977d7..438fe522c870 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -21,8 +21,8 @@ trap remove_docker_container EXIT
remove_docker_container
# Try building the docker image
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu .
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --progress plain --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
@@ -73,12 +73,11 @@ function cpu_tests() {
pytest -x -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs"
- # Note: disable it until supports V1
- # Run AWQ test
- # docker exec cpu-test-"$NUMA_NODE" bash -c "
- # set -e
- # pytest -x -s -v \
- # tests/quantization/test_ipex_quant.py"
+ # Run AWQ/GPTQ test
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -x -s -v \
+ tests/quantization/test_cpu_wna16.py"
# Run multi-lora tests
docker exec cpu-test-"$NUMA_NODE" bash -c "
diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
index 27ed67c4517e..4d163399cfc6 100644
--- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -35,7 +35,7 @@ docker run \
echo $ZE_AFFINITY_MASK
pip install tblib==3.1.0
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
- python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
+ python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
@@ -46,6 +46,6 @@ docker run \
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
pytest -v -s v1/structured_output
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py
- pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
+ pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py
pytest -v -s v1/test_serial_utils.py
'
diff --git a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
index 5302f524a0ae..8106f50f18f6 100644
--- a/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
+++ b/.buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh
@@ -17,7 +17,17 @@ wait_for_server() {
}
MODEL="deepseek-ai/DeepSeek-V2-lite"
-BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+
+# Set BACKENDS based on platform
+if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
+ # ROCm platform
+ BACKENDS=("allgather_reducescatter")
+ # Disable MOE padding for ROCm since it is causing eplb to fail
+  # Disable MoE padding on ROCm since it causes EPLB to fail
+else
+ # Non-ROCm platform (CUDA/other)
+ BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
diff --git a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
similarity index 64%
rename from .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
rename to .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
index a5135299297e..6a1bef275d04 100644
--- a/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh
+++ b/.buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh
@@ -1,10 +1,12 @@
#!/usr/bin/env bash
set -euxo pipefail
-# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
+# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] [DATA_PARALLEL_SIZE] [TENSOR_PARALLEL_SIZE]
THRESHOLD=${1:-0.8}
NUM_Q=${2:-1319}
PORT=${3:-8020}
+DATA_PARALLEL_SIZE=${4:-2}
+TENSOR_PARALLEL_SIZE=${5:-2}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"
@@ -17,7 +19,16 @@ wait_for_server() {
}
MODEL="QWen/Qwen3-30B-A3B-FP8"
-BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+# Set BACKENDS based on platform
+if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
+ # ROCm platform
+ BACKENDS=("allgather_reducescatter")
+  # Disable MoE padding on ROCm since it causes EPLB to fail
+ export VLLM_ROCM_MOE_PADDING=0
+else
+ # Non-ROCm platform (CUDA/other)
+ BACKENDS=("deepep_high_throughput" "deepep_low_latency")
+fi
cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
@@ -36,8 +47,10 @@ for BACK in "${BACKENDS[@]}"; do
VLLM_ALL2ALL_BACKEND=$BACK \
vllm serve "$MODEL" \
--enforce-eager \
- --tensor-parallel-size 2 \
- --data-parallel-size 2 \
+ --enable-eplb \
+ --eplb-config '{"window_size":10, "step_interval":100, "num_redundant_experts":0, "log_balancedness":true}' \
+ --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
+ --data-parallel-size ${DATA_PARALLEL_SIZE} \
--enable-expert-parallel \
--trust-remote-code \
--max-model-len 2048 \
diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh
index 945c5e48c009..05accb9cf16d 100644
--- a/.buildkite/scripts/upload-wheels.sh
+++ b/.buildkite/scripts/upload-wheels.sh
@@ -2,6 +2,28 @@
set -ex
+# ======== part 0: setup ========
+
+BUCKET="vllm-wheels"
+INDICES_OUTPUT_DIR="indices"
+DEFAULT_VARIANT_ALIAS="cu129" # align with VLLM_MAIN_CUDA_VERSION in vllm/envs.py
+PYTHON=${PYTHON_PROG:=python3} # try to read from env var, otherwise use python3
+SUBPATH=$BUILDKITE_COMMIT
+S3_COMMIT_PREFIX="s3://$BUCKET/$SUBPATH/"
+
+# detect if python3.12+ is available (generate-nightly-index.py needs it, e.g. for Path.relative_to(walk_up=True))
+has_new_python=$($PYTHON -c "print(1 if __import__('sys').version_info >= (3,12) else 0)")
+if [[ "$has_new_python" -eq 0 ]]; then
+ # use new python from docker
+ docker pull python:3-slim
+ PYTHON="docker run --rm -v $(pwd):/app -w /app python:3-slim python3"
+fi
+
+echo "Using python interpreter: $PYTHON"
+echo "Python version: $($PYTHON --version)"
+
+# ========= part 1: collect, rename & upload the wheel ==========
+
# Assume wheels are in artifacts/dist/*.whl
wheel_files=(artifacts/dist/*.whl)
@@ -10,74 +32,69 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then
echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}"
exit 1
fi
-
-# Get the single wheel file
wheel="${wheel_files[0]}"
-# Detect architecture and rename 'linux' to appropriate manylinux version
-arch=$(uname -m)
-if [[ $arch == "x86_64" ]]; then
- manylinux_version="manylinux1"
-elif [[ $arch == "aarch64" ]]; then
- manylinux_version="manylinux2014"
-else
- echo "Warning: Unknown architecture $arch, using manylinux1 as default"
- manylinux_version="manylinux1"
-fi
+# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31
+# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels
+manylinux_version="manylinux_2_31"
# Rename 'linux' to the appropriate manylinux version in the wheel filename
+if [[ "$wheel" != *"linux"* ]]; then
+ echo "Error: Wheel filename does not contain 'linux': $wheel"
+ exit 1
+fi
new_wheel="${wheel/linux/$manylinux_version}"
mv -- "$wheel" "$new_wheel"
wheel="$new_wheel"
+echo "Renamed wheel to: $wheel"
# Extract the version from the wheel
version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2)
-echo "Version: $version"
-
-normal_wheel="$wheel" # Save the original wheel filename
-
-# If the version contains "dev", rename it to v1.0.0.dev for consistency
-if [[ $version == *dev* ]]; then
- suffix="${version##*.}"
- if [[ $suffix == cu* ]]; then
- new_version="1.0.0.dev+${suffix}"
- else
- new_version="1.0.0.dev"
- fi
- new_wheel="${wheel/$version/$new_version}"
- # use cp to keep both files in the artifacts directory
- cp -- "$wheel" "$new_wheel"
- wheel="$new_wheel"
- version="$new_version"
-fi
+echo "Version in wheel: $version"
+pure_version="${version%%+*}"
+echo "Pure version (without variant): $pure_version"
-# Upload the wheel to S3
-python3 .buildkite/generate_index.py --wheel "$normal_wheel"
+# copy the wheel to the per-commit prefix
+aws s3 cp "$wheel" "$S3_COMMIT_PREFIX"
-# generate index for this commit
-aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
-aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
+# ========= part 2: generate and upload indices ==========
+# generate indices for all existing wheels in the commit directory
+# This script might be run multiple times if multiple variants are being built,
+# so we need to minimize the window for "TOCTOU" issues,
+# i.e., one process generating indices while another uploads a new wheel.
+# Therefore, avoid time-consuming operations below this point.
-if [[ $normal_wheel == *"cu129"* ]]; then
- # only upload index.html for cu129 wheels (default wheels) as it
- # is available on both x86 and arm64
- aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
- aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
+# list all wheels in the commit directory
+echo "Existing wheels on S3:"
+aws s3 ls "$S3_COMMIT_PREFIX"
+obj_json="objects.json"
+aws s3api list-objects-v2 --bucket "$BUCKET" --prefix "$SUBPATH/" --delimiter / --output json > "$obj_json"
+mkdir -p "$INDICES_OUTPUT_DIR"
+
+# call the script to generate indices for all existing wheels.
+# These indices use relative paths, so they keep working as long as they sit next to the wheel directory in S3,
+# i.e., the wheels are always in s3://vllm-wheels/<commit>/
+# and the indices can be placed in /<commit>/, /nightly/, or /<version>/
+if [[ ! -z "$DEFAULT_VARIANT_ALIAS" ]]; then
+ alias_arg="--alias-to-default $DEFAULT_VARIANT_ALIAS"
else
- echo "Skipping index files for non-cu129 wheels"
+ alias_arg=""
fi
-# generate index for nightly
-aws s3 cp "$wheel" "s3://vllm-wheels/nightly/"
-aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"
+$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg
-if [[ $normal_wheel == *"cu129"* ]]; then
- # only upload index.html for cu129 wheels (default wheels) as it
- # is available on both x86 and arm64
- aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
-else
- echo "Skipping index files for non-cu129 wheels"
+# copy indices to /<commit>/ unconditionally
+echo "Uploading indices to $S3_COMMIT_PREFIX"
+aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "$S3_COMMIT_PREFIX"
+
+# copy to /nightly/ only if it is on the main branch and not a PR
+if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]]; then
+ echo "Uploading indices to overwrite /nightly/"
+ aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/"
fi
-aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
-aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html"
+# copy to /<version>/ only if it does not have "dev" in the version
+if [[ "$version" != *"dev"* ]]; then
+ echo "Uploading indices to overwrite /$pure_version/"
+ aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
+fi
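As a rough illustration of how the published layout could be consumed (this client-side snippet is not part of the patch; it assumes the wheels.vllm.ai hosting and metadata.json format described in generate-nightly-index.py), a consumer might resolve a concrete wheel URL like this:

```python
# Hypothetical consumer-side sketch (not part of this patch): resolve a wheel URL
# from a variant's per-package metadata.json, assuming indices are hosted under
# https://wheels.vllm.ai/<commit, nightly, or version>/ as described above.
import json
import urllib.request
from urllib.parse import urljoin


def find_wheel_url(index_root: str, variant: str | None, platform_tag: str) -> str | None:
    """index_root e.g. 'https://wheels.vllm.ai/nightly/'; variant e.g. 'cu130', or None for the default."""
    package_dir = urljoin(index_root, (f"{variant}/" if variant else "") + "vllm/")
    with urllib.request.urlopen(urljoin(package_dir, "metadata.json")) as resp:
        entries = json.load(resp)
    for entry in entries:
        if entry["platform_tag"] == platform_tag:
            # "path" is relative to the directory holding metadata.json
            return urljoin(package_dir, entry["path"])
    return None


print(find_wheel_url("https://wheels.vllm.ai/nightly/", "cu130", "manylinux1_x86_64"))
```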
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 5fd048c2ad0c..687b6b08507c 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -61,8 +61,8 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_
-- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins
- timeout_in_minutes: 10
+- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
+ timeout_in_minutes: 20
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
@@ -72,14 +72,18 @@ steps:
- tests/test_outputs.py
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
+ - tests/tokenizers_
- tests/transformers_utils
+ - tests/config
no_gpu: true
commands:
- python3 standalone_tests/lazy_imports.py
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
+ - pytest -v -s tokenizers_
- pytest -v -s transformers_utils
+ - pytest -v -s config
- label: Python-only Installation Test # 10min
timeout_in_minutes: 20
@@ -187,7 +191,7 @@ steps:
- tests/distributed/test_utils
- tests/distributed/test_pynccl
- tests/distributed/test_events
- - tests/compile/test_basic_correctness
+ - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
@@ -215,7 +219,7 @@ steps:
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- - pytest -v -s compile/test_basic_correctness.py
+ - pytest -v -s compile/fullgraph/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
@@ -226,6 +230,27 @@ steps:
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
+- label: Distributed Tests (8 GPUs) # 4min
+ timeout_in_minutes: 10
+ mirror_hardwares: [amdexperimental]
+ agent_pool: mi325_8
+ # grade: Blocking
+ gpu: h100
+ num_gpus: 8
+ working_dir: "/vllm-workspace/tests"
+ source_file_dependencies:
+ - examples/offline_inference/torchrun_dp_example.py
+ - vllm/config/parallel.py
+ - vllm/distributed/
+ - vllm/v1/engine/llm_engine.py
+ - vllm/v1/executor/uniproc_executor.py
+ - vllm/v1/worker/gpu_worker.py
+ commands:
+ # https://github.com/NVIDIA/nccl/issues/1838
+ #- export NCCL_CUMEM_HOST_ENABLE=0
+ # test with torchrun tp=2 and dp=4 with ep
+ - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
+
- label: EPLB Algorithm Test # 5min
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
@@ -238,11 +263,11 @@ steps:
commands:
- pytest -v -s distributed/test_eplb_algo.py
-- label: EPLB Execution Test # 5min
+- label: EPLB Execution Test # 10min
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_4
# grade: Blocking
- timeout_in_minutes: 15
+ timeout_in_minutes: 20
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -250,6 +275,7 @@ steps:
- tests/distributed/test_eplb_execute.py
commands:
- pytest -v -s distributed/test_eplb_execute.py
+ - pytest -v -s distributed/test_eplb_spec_decode.py
- label: Metrics, Tracing Test # 12min
timeout_in_minutes: 20
@@ -273,7 +299,7 @@ steps:
- label: Regression Test # 7min
timeout_in_minutes: 20
- mirror_hardwares: [amdexperimental, amdproduction]
+ mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
agent_pool: mi325_1
grade: Blocking
source_file_dependencies:
@@ -284,23 +310,20 @@ steps:
- pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional
-- label: Engine Test # 25min
- timeout_in_minutes: 40
+- label: Engine Test # 9min
+ timeout_in_minutes: 15
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
- #grade: Blocking
+ # grade: Blocking
source_file_dependencies:
- vllm/
- tests/engine
- - tests/tokenization
- tests/test_sequence
- tests/test_config
- tests/test_logger
- tests/test_vllm_port
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
- # OOM in the CI unless we run this separately
- - pytest -v -s tokenization
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
@@ -337,6 +360,7 @@ steps:
- tests/v1
commands:
# split the test to avoid interference
+ - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- pytest -v -s -m 'not cpu_test' v1/core
- pytest -v -s v1/executor
- pytest -v -s v1/kv_offload
@@ -344,7 +368,7 @@ steps:
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker
- pytest -v -s v1/spec_decode
- - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_lmcache_integration.py
+ - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit
- pytest -v -s -m 'not cpu_test' v1/metrics
- pytest -v -s v1/test_oracle.py
- pytest -v -s v1/test_request.py
@@ -353,6 +377,29 @@ steps:
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
+# TODO: Add the "V1 Test attention (MI300)" test group
+
+- label: V1 Test attention (H100) # 10min
+ mirror_hardwares: [amdexperimental]
+ agent_pool: mi325_1
+ # grade: Blocking
+ timeout_in_minutes: 30
+ gpu: h100
+ source_file_dependencies:
+ - vllm/v1/attention
+ - tests/v1/attention
+ commands:
+ - pytest -v -s v1/attention
+
+- label: V1 Test attention (B200) # 10min
+ timeout_in_minutes: 30
+ gpu: b200
+ source_file_dependencies:
+ - vllm/v1/attention
+ - tests/v1/attention
+ commands:
+ - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
+
- label: V1 Test others (CPU) # 5 mins
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
@@ -456,17 +503,12 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_pass_manager.py
- - pytest -v -s compile/test_fusion.py
- - pytest -v -s compile/test_fusion_attn.py
- - pytest -v -s compile/test_functionalization.py
- - pytest -v -s compile/test_silu_mul_quant_fusion.py
- # - pytest -v -s compile/test_sequence_parallelism.py
- # - pytest -v -s compile/test_async_tp.py
- - pytest -v -s compile/test_fusion_all_reduce.py
- - pytest -v -s compile/test_decorator.py
- - pytest -v -s compile/test_noop_elimination.py
- - pytest -v -s compile/test_aot_compile.py
+ # Run unit tests defined directly under compile/,
+ # not including subdirectories, which are usually heavier
+ # tests covered elsewhere.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
@@ -478,11 +520,14 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_basic_correctness.py
- - pytest -v -s compile/piecewise/
+ # Run smoke tests under fullgraph directory, except test_full_graph.py
+ # as it is a heavy test that is covered in other steps.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;"
-- label: PyTorch Fullgraph Test # 22min
- timeout_in_minutes: 35
+- label: PyTorch Fullgraph Test # 27min
+ timeout_in_minutes: 40
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
@@ -491,8 +536,23 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_full_graph.py
- - pytest -v -s compile/test_fusions_e2e.py
+ - pytest -v -s compile/fullgraph/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
+ # Limit to no custom ops to reduce running time
+ # Wrap with quotes to escape yaml and avoid starting -k string with a -
+ - "pytest -v -s compile/distributed/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
+
+- label: Cudagraph test
+ timeout_in_minutes: 20
+ mirror_hardwares: [amdexperimental, amdproduction]
+ agent_pool: mi325_1
+ source_file_dependencies:
+ - tests/v1/cudagraph
+ - vllm/v1/cudagraph_dispatcher.py
+ - vllm/config/compilation.py
+ - vllm/compilation
+ commands:
+ - pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py
+ - pytest -v -s v1/cudagraph/test_cudagraph_mode.py
- label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75
@@ -544,6 +604,8 @@ steps:
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/
- vllm/distributed/device_communicators/
+ - vllm/envs.py
+ - vllm/config
commands:
- pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2
@@ -562,10 +624,13 @@ steps:
- label: Model Executor Test # 23min
timeout_in_minutes: 35
+ torch_nightly: true
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
+ - vllm/engine/arg_utils.py
+ - vllm/config/model.py
- vllm/model_executor
- tests/model_executor
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
@@ -639,7 +704,7 @@ steps:
- vllm/model_executor/models/whisper.py
commands: # LMEval
# Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442
- - pytest -s entrypoints/openai/correctness/ --ignore entrypoints/openai/correctness/test_transcription_api_correctness.py
+ - pytest -s entrypoints/openai/correctness/
- label: OpenAI-Compatible Tool Use # 23 min
timeout_in_minutes: 35
@@ -688,6 +753,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
+ - vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
@@ -861,9 +927,10 @@ steps:
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Accuracy Eval (Small Models) # 10min
+ timeout_in_minutes: 70
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
- timeout_in_minutes: 15
+ # grade: Blocking
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- vllm/multimodal/
@@ -934,16 +1001,17 @@ steps:
- label: Transformers Nightly Models Test
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
+ # grade: Blocking
working_dir: "/vllm-workspace/"
optional: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- - pytest -v -s tests/models/test_initialization.py
+ - pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
- pytest -v -s tests/models/test_transformers.py
- - pytest -v -s tests/models/multimodal/processing/
- - pytest -v -s tests/models/multimodal/test_mapping.py
+ # - pytest -v -s tests/models/multimodal/processing/
+ - pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
- python3 examples/offline_inference/basic/chat.py
- - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
+ # - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
@@ -961,11 +1029,16 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
+ - vllm/v1/attention/backends/mla/cutlass_mla.py
+ - vllm/v1/attention/backends/mla/flashinfer_mla.py
+ - vllm/platforms/cuda.py
+ - vllm/attention/selector.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
# Attention
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
+ - pytest -v -s tests/kernels/attention/test_attention_selector.py
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
@@ -983,7 +1056,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
-- label: Blackwell Fusion Tests # 30 min
+- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
@@ -1001,13 +1074,40 @@ steps:
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+ # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
+ # Wrap with quotes to escape yaml
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
+ # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
+ - pytest -v -s tests/compile/distributed/test_full_graph.py::test_fp8_kv_scale_compile
+
+- label: Blackwell Fusion E2E Tests # 30 min
+ timeout_in_minutes: 40
+ working_dir: "/vllm-workspace/"
+ gpu: b200
+ optional: true
+ num_gpus: 2
+ source_file_dependencies:
+ - csrc/quantization/fp4/
+ - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+ - vllm/v1/attention/backends/flashinfer.py
+ - vllm/compilation/
+ # can affect pattern matching
+ - vllm/model_executor/layers/layernorm.py
+ - vllm/model_executor/layers/activation.py
+ - vllm/model_executor/layers/quantization/input_quant_fp8.py
+ - tests/compile/distributed/test_fusions_e2e.py
+ - tests/compile/fullgraph/test_full_graph.py
+ commands:
+ - nvidia-smi
+ # Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py
-- label: Blackwell GPT-OSS Eval
+- label: ROCm GPT-OSS Eval
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
- gpu: b200
+ agent_pool: mi325_1
+ mirror_hardwares: [amdexperimental, amdproduction]
optional: true # run on nightlies
source_file_dependencies:
- tests/evals/gpt_oss
@@ -1016,7 +1116,7 @@ steps:
- vllm/v1/attention/backends/flashinfer.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
+ - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: Blackwell Quantized MoE Test
timeout_in_minutes: 60
@@ -1106,7 +1206,7 @@ steps:
- vllm/worker/worker_base.py
- vllm/v1/engine/
- vllm/v1/worker/
- - tests/compile/test_basic_correctness.py
+ - tests/compile/fullgraph/test_basic_correctness.py
- tests/compile/test_wrapper.py
- tests/distributed/
- tests/entrypoints/llm/test_collective_rpc.py
@@ -1119,7 +1219,7 @@ steps:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- - pytest -v -s ./compile/test_basic_correctness.py
+ - pytest -v -s ./compile/fullgraph/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
@@ -1219,7 +1319,10 @@ steps:
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- - pytest -v -s -x lora/test_gptoss_tp.py
+
+  # Disabled for now because the MXFP4 backend doesn't support
+  # LoRA on non-CUDA platforms yet
+ #- pytest -v -s -x lora/test_gptoss_tp.py
- label: Weight Loading Multiple GPU Test # 33min
@@ -1234,7 +1337,7 @@ steps:
- vllm/
- tests/weight_loading
commands:
- - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
+ - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-amd.txt
- label: Weight Loading Multiple GPU Test - Large Models # optional
mirror_hardwares: [amdexperimental]
@@ -1242,17 +1345,17 @@ steps:
# grade: Blocking
working_dir: "/vllm-workspace/tests"
num_gpus: 2
- gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
+ - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large-amd.txt
- label: NixlConnector PD accuracy tests (Distributed) # 30min
mirror_hardwares: [amdexperimental]
agent_pool: mi325_4
+ # grade: Blocking
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
num_gpus: 4
@@ -1267,6 +1370,9 @@ steps:
##### A100 test #####
- label: Distributed Tests (A100) # optional
+ mirror_hardwares: [amdexperimental]
+ agent_pool: mi325_4
+ # grade: Blocking
gpu: a100
optional: true
num_gpus: 4
@@ -1281,6 +1387,9 @@ steps:
- pytest -v -s -x lora/test_mixtral.py
- label: LM Eval Large Models # optional
+ mirror_hardwares: [amdexperimental, amdproduction]
+ agent_pool: mi325_4
+ # grade: Blocking
gpu: a100
optional: true
num_gpus: 4
@@ -1292,19 +1401,41 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
+##### H100 test #####
+- label: LM Eval Large Models (H100) # optional
+ mirror_hardwares: [amdexperimental, amdproduction]
+ agent_pool: mi325_4
+ # grade: Blocking
+ gpu: h100
+ optional: true
+ num_gpus: 4
+ working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
+ source_file_dependencies:
+ - csrc/
+ - vllm/model_executor/layers/quantization
+ commands:
+ - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100
+ - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4
+
##### H200 test #####
- label: Distributed Tests (H200) # optional
+ mirror_hardwares: [amdexperimental]
+ agent_pool: mi325_2
+ # grade: Blocking
gpu: h200
optional: true
working_dir: "/vllm-workspace/"
num_gpus: 2
commands:
- - pytest -v -s tests/compile/test_async_tp.py
- - pytest -v -s tests/compile/test_sequence_parallelism.py
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
- - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
+ - pytest -v -s tests/compile/distributed/test_async_tp.py
+ - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+ #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
+ - pytest -v -s tests/compile/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
+ - pytest -v -s tests/v1/distributed/test_dbo.py
##### B200 test #####
- label: Distributed Tests (B200) # optional
@@ -1315,6 +1446,7 @@ steps:
commands:
- pytest -v -s tests/distributed/test_context_parallel.py
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
+ - pytest -v -s tests/v1/distributed/test_dbo.py
##### RL Integration Tests #####
- label: Prime-RL Integration Test # 15min
@@ -1330,3 +1462,27 @@ steps:
- .buildkite/scripts/run-prime-rl-test.sh
commands:
- bash .buildkite/scripts/run-prime-rl-test.sh
+
+- label: DeepSeek V2-Lite Accuracy
+ mirror_hardwares: [amdexperimental]
+ agent_pool: mi325_4
+ # grade: Blocking
+ timeout_in_minutes: 60
+ gpu: h100
+ optional: true
+ num_gpus: 4
+ working_dir: "/vllm-workspace"
+ commands:
+ - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010
+
+- label: Qwen3-30B-A3B-FP8-block Accuracy
+ mirror_hardwares: [amdexperimental]
+ agent_pool: mi325_4
+ # grade: Blocking
+ timeout_in_minutes: 60
+ gpu: h100
+ optional: true
+ num_gpus: 4
+ working_dir: "/vllm-workspace"
+ commands:
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index be1b79ddc432..9f2107fb1e5a 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -57,14 +57,15 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_
-- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
- timeout_in_minutes: 10
+- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
+ timeout_in_minutes: 20
source_file_dependencies:
- vllm/
- tests/test_inputs.py
- tests/test_outputs.py
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
+ - tests/tokenizers_
- tests/transformers_utils
- tests/config
no_gpu: true
@@ -73,6 +74,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
+ - pytest -v -s tokenizers_
- pytest -v -s transformers_utils
- pytest -v -s config
@@ -167,7 +169,7 @@ steps:
- tests/distributed/test_utils
- tests/distributed/test_pynccl
- tests/distributed/test_events
- - tests/compile/test_basic_correctness
+ - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
@@ -192,12 +194,13 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- - pytest -v -s compile/test_basic_correctness.py
+ - pytest -v -s compile/fullgraph/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
@@ -275,21 +278,18 @@ steps:
- pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional
-- label: Engine Test # 25min
- timeout_in_minutes: 40
+- label: Engine Test # 9min
+ timeout_in_minutes: 15
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/engine
- - tests/tokenization
- tests/test_sequence
- tests/test_config
- tests/test_logger
- tests/test_vllm_port
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
- # OOM in the CI unless we run this separately
- - pytest -v -s tokenization
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
@@ -346,6 +346,18 @@ steps:
commands:
- pytest -v -s v1/attention
+- label: Batch Invariance Tests (H100) # 10min
+ timeout_in_minutes: 25
+ gpu: h100
+ source_file_dependencies:
+ - vllm/
+ - tests/v1/determinism/
+ commands:
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pip install pytest-timeout pytest-forked
+ - pytest -v -s v1/determinism/test_batch_invariance.py
+ - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py
+
- label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30
gpu: b200
@@ -445,17 +457,12 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_config.py
- - pytest -v -s compile/test_pass_manager.py
- - pytest -v -s compile/test_fusion.py
- - pytest -v -s compile/test_fusion_attn.py
- - pytest -v -s compile/test_functionalization.py
- - pytest -v -s compile/test_silu_mul_quant_fusion.py
- - pytest -v -s compile/test_fusion_all_reduce.py
- - pytest -v -s compile/test_decorator.py
- - pytest -v -s compile/test_noop_elimination.py
- - pytest -v -s compile/test_aot_compile.py
- - pytest -v -s compile/test_qk_norm_rope_fusion.py
+ # Run unit tests defined directly under compile/,
+ # excluding subdirectories, whose heavier tests are
+ # covered elsewhere.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
@@ -465,9 +472,11 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_basic_correctness.py
- - pytest -v -s compile/test_multimodal_compile.py
- - pytest -v -s compile/piecewise/
+ # Run smoke tests under fullgraph directory, except test_full_graph.py
+ # as it is a heavy test that is covered in other steps.
+ # Use `find` to launch multiple instances of pytest so that
+ # they do not suffer from https://github.com/vllm-project/vllm/issues/28965
+ - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;"
- label: PyTorch Fullgraph Test # 27min
timeout_in_minutes: 40
@@ -477,10 +486,11 @@ steps:
- vllm/
- tests/compile
commands:
- - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
+ # fp8 kv scales not supported on sm89, tested on Blackwell instead
+ - pytest -v -s compile/fullgraph/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a -
- - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
+ - "pytest -v -s compile/distributed/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test
timeout_in_minutes: 20
@@ -552,6 +562,25 @@ steps:
commands:
- pytest -v -s kernels/mamba
+- label: Kernels DeepGEMM Test (H100)
+ timeout_in_minutes: 45
+ gpu: h100
+ num_gpus: 1
+ source_file_dependencies:
+ - tools/install_deepgemm.sh
+ - vllm/utils/deep_gemm.py
+ - vllm/model_executor/layers/fused_moe
+ - vllm/model_executor/layers/quantization
+ - tests/kernels/quantization/test_block_fp8.py
+ - tests/kernels/moe/test_deepgemm.py
+ - tests/kernels/moe/test_batched_deepgemm.py
+ - tests/kernels/attention/test_deepgemm_attention.py
+ commands:
+ - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm
+ - pytest -v -s kernels/moe/test_deepgemm.py
+ - pytest -v -s kernels/moe/test_batched_deepgemm.py
+ - pytest -v -s kernels/attention/test_deepgemm_attention.py
+
- label: Model Executor Test # 23min
timeout_in_minutes: 35
torch_nightly: true
@@ -602,6 +631,7 @@ steps:
# we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
+ - uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
- label: LM Eval Small Models # 53min
@@ -662,6 +692,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
+ - vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
@@ -788,14 +819,24 @@ steps:
commands:
- pytest -v -s models/language/pooling_mteb_test
-- label: Multi-Modal Processor Test # 44min
+- label: Multi-Modal Processor Test (CPU)
+ timeout_in_minutes: 60
+ source_file_dependencies:
+ - vllm/
+ - tests/models/multimodal
+ no_gpu: true
+ commands:
+ - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+ - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
+
+- label: Multi-Modal Processor Test
timeout_in_minutes: 60
source_file_dependencies:
- vllm/
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- - pytest -v -s models/multimodal/processing
+ - pytest -v -s models/multimodal/processing/test_tensor_schema.py
- label: Multi-Modal Models Test (Standard) # 60min
timeout_in_minutes: 80
@@ -872,14 +913,15 @@ steps:
- label: Transformers Nightly Models Test
working_dir: "/vllm-workspace/"
optional: true
+ soft_fail: true
commands:
- pip install --upgrade git+https://github.com/huggingface/transformers
- - pytest -v -s tests/models/test_initialization.py -k 'not (Gemma3 or ModernBert or Qwen2_5_VL or Qwen2_5vl or Qwen2VL or TransformersMultiModalEmbeddingModel or TransformersMultiModalForSequenceClassification or Ultravox or Phi4Multimodal or LlavaNextVideo or MiniCPMO or Lfm2Moe or PaliGemma or RobertaForSequenceClassification or Ovis2_5 or Fuyu or DeepseekOCR or KimiVL)'
+ - pytest -v -s tests/models/test_initialization.py
- pytest -v -s tests/models/test_transformers.py
- # - pytest -v -s tests/models/multimodal/processing/
- - pytest -v -s tests/models/multimodal/test_mapping.py -k 'not (Gemma3 or Qwen2VL or Qwen2_5_VL)'
+ - pytest -v -s tests/models/multimodal/processing/
+ - pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py
- # - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
+ - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
@@ -923,8 +965,9 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
+ - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
-- label: Blackwell Fusion Tests # 30 min
+- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
@@ -932,20 +975,29 @@ steps:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
+ - vllm/v1/worker/
+ - vllm/v1/cudagraph_dispatcher.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
+ - tests/compile/test_fusion_attn.py
+ - tests/compile/test_silu_mul_quant_fusion.py
+ - tests/compile/distributed/test_fusion_all_reduce.py
+ - tests/compile/distributed/test_fusions_e2e.py
+ - tests/compile/fullgraph/test_full_graph.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml
- - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
+ - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
+ # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
+ - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40
@@ -962,14 +1014,11 @@ steps:
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
- - tests/compile/test_fusions_e2e.py
- - tests/compile/test_full_graph.py
+ - tests/compile/distributed/test_fusions_e2e.py
commands:
- nvidia-smi
# Run all e2e fusion tests
- - pytest -v -s tests/compile/test_fusions_e2e.py
- # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
+ - pytest -v -s tests/compile/distributed/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
@@ -1067,7 +1116,7 @@ steps:
- vllm/worker/worker_base.py
- vllm/v1/engine/
- vllm/v1/worker/
- - tests/compile/test_basic_correctness.py
+ - tests/compile/fullgraph/test_basic_correctness.py
- tests/compile/test_wrapper.py
- tests/distributed/
- tests/entrypoints/llm/test_collective_rpc.py
@@ -1079,10 +1128,11 @@ steps:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
+ - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- - pytest -v -s ./compile/test_basic_correctness.py
+ - pytest -v -s ./compile/fullgraph/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
@@ -1262,10 +1312,11 @@ steps:
working_dir: "/vllm-workspace/"
num_gpus: 2
commands:
- - pytest -v -s tests/compile/test_async_tp.py
- - pytest -v -s tests/compile/test_sequence_parallelism.py
- - pytest -v -s tests/compile/test_fusion_all_reduce.py
- - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
+ - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
+ - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
+ - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+ - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
+ - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py
@@ -1302,11 +1353,20 @@ steps:
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010
-- label: Qwen3-30B-A3B-FP8-block Accuracy
+- label: Qwen3-30B-A3B-FP8-block Accuracy (H100)
timeout_in_minutes: 60
gpu: h100
optional: true
num_gpus: 4
working_dir: "/vllm-workspace"
commands:
- - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep.sh 0.8 200 8020
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020
+
+- label: Qwen3-30B-A3B-FP8-block Accuracy (B200)
+ timeout_in_minutes: 60
+ gpu: b200
+ optional: true
+ num_gpus: 2
+ working_dir: "/vllm-workspace"
+ commands:
+ - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
\ No newline at end of file
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index f26c782bccf2..ecb10d1a450f 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -3,12 +3,13 @@
# This lists cover the "core" components of vLLM that require careful review
/vllm/attention @LucasWilkinson
-/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
-/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
+/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @njhill
+/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @njhill @22quinn
/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety
/vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn
+/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -20,27 +21,30 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
# Any change to the VllmConfig changes can have a large user-facing impact,
# so spam a lot of people
-/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg
-/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
+/vllm/config @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg
+/vllm/config/cache.py @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
# vLLM V1
/vllm/v1/attention @LucasWilkinson
/vllm/v1/attention/backends/mla @pavanimajety
/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety
/vllm/v1/attention/backends/triton_attn.py @tdoublep
-/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
+/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC
/vllm/v1/sample @22quinn @houseroad @njhill
/vllm/v1/spec_decode @benchislett @luccafong
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/offloading @ApostaC
+# Model runner V2
+/vllm/v1/worker/gpu @WoosukKwon
+
# Test ownership
-/.buildkite/lm-eval-harness @mgoin @simon-mo
+/.buildkite/lm-eval-harness @mgoin
/tests/distributed/test_multi_node_assignment.py @youkaichao
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
-/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm @NickLucche
+/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/models @DarkLight1337 @ywang96
@@ -49,15 +53,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
/tests/v1/structured_output @mgoin @russellb @aarnphm
-/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
+/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC
/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/lora @jeejeelee
/tests/models/language/generation/test_hybrid.py @tdoublep
/tests/v1/kv_connector/nixl_integration @NickLucche
/tests/v1/kv_connector @ApostaC
/tests/v1/offloading @ApostaC
+/tests/v1/determinism @yewentao256
-# Transformers backend
+# Transformers modeling backend
/vllm/model_executor/models/transformers @hmellor
/tests/models/test_transformers.py @hmellor
@@ -144,6 +149,7 @@ mkdocs.yaml @hmellor
/examples/*/pooling/ @noooop
/tests/models/*/pooling* @noooop
/tests/entrypoints/pooling @noooop
+/vllm/entrypoints/pooling @aarnphm @chaunceyjiang @noooop
/vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler.py @noooop
diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml
index c3e132a536a4..861290ea43c8 100644
--- a/.github/workflows/cleanup_pr_body.yml
+++ b/.github/workflows/cleanup_pr_body.yml
@@ -13,7 +13,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml
index 7d565ef9f2e4..629966b95933 100644
--- a/.github/workflows/issue_autolabel.yml
+++ b/.github/workflows/issue_autolabel.yml
@@ -105,6 +105,31 @@ jobs:
}
],
},
+ cpu: {
+ // Keyword search - matches whole words only (with word boundaries)
+ keywords: [
+ {
+ term: "CPU Backend",
+ searchIn: "title"
+ },
+ {
+ term: "x86",
+ searchIn: "title"
+ },
+ {
+ term: "ARM",
+ searchIn: "title"
+ },
+ {
+ term: "Apple Silicon",
+ searchIn: "title"
+ },
+ {
+ term: "IBM Z",
+ searchIn: "title"
+ },
+ ],
+ },
// Add more label configurations here as needed
// example: {
// keywords: [...],
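
The `cpu` block above relies on the workflow's whole-word keyword search against the issue title. As a rough illustration only (the real matching lives in the JavaScript of issue_autolabel.yml, and case handling here is an assumption), the configured behavior can be sketched in Python with a word-boundary regex:

import re

# Hypothetical sketch of "matches whole words only (with word boundaries)"
# for the cpu label's title keywords listed above.
CPU_TITLE_KEYWORDS = ["CPU Backend", "x86", "ARM", "Apple Silicon", "IBM Z"]

def title_matches_cpu(title: str) -> bool:
    # \b keeps "ARM" from matching inside words such as "alarm" or "disarm".
    return any(
        re.search(rf"\b{re.escape(term)}\b", title, flags=re.IGNORECASE)
        for term in CPU_TITLE_KEYWORDS
    )

print(title_matches_cpu("[Bug] x86 CPU build fails"))  # True
print(title_matches_cpu("Disarm the alarm handler"))   # False
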
diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml
new file mode 100644
index 000000000000..3a12c4b3a830
--- /dev/null
+++ b/.github/workflows/macos-smoke-test.yml
@@ -0,0 +1,80 @@
+name: macOS Apple Silicon Smoke Test
+
+on:
+ push:
+ branches:
+ - main
+ workflow_dispatch: # Manual trigger
+
+jobs:
+ macos-m1-smoke-test:
+ runs-on: macos-latest
+ timeout-minutes: 30
+
+ steps:
+ - uses: actions/checkout@v6
+
+ - uses: astral-sh/setup-uv@v7
+ with:
+ enable-cache: true
+ cache-dependency-glob: |
+ requirements/**/*.txt
+ pyproject.toml
+ python-version: '3.12'
+
+ - name: Create virtual environment
+ run: |
+ uv venv
+ echo "$GITHUB_WORKSPACE/.venv/bin" >> "$GITHUB_PATH"
+
+ - name: Install dependencies and build vLLM
+ run: |
+ uv pip install -r requirements/cpu.txt --index-strategy unsafe-best-match
+ uv pip install -e .
+ env:
+ CMAKE_BUILD_PARALLEL_LEVEL: 4
+
+ - name: Verify installation
+ run: |
+ python -c "import vllm; print(f'vLLM version: {vllm.__version__}')"
+
+ - name: Smoke test vllm serve
+ run: |
+ # Start server in background
+ vllm serve Qwen/Qwen3-0.6B \
+ --max-model-len=2K \
+ --load-format=dummy \
+ --hf-overrides '{"num_hidden_layers": 2}' \
+ --enforce-eager \
+ --port 8000 &
+
+ SERVER_PID=$!
+
+ # Wait for server to start
+ for i in {1..30}; do
+ if curl -s http://localhost:8000/health > /dev/null; then
+ echo "Server started successfully"
+ break
+ fi
+ if [ "$i" -eq 30 ]; then
+ echo "Server failed to start"
+ kill "$SERVER_PID"
+ exit 1
+ fi
+ sleep 2
+ done
+
+ # Test health endpoint
+ curl -f http://localhost:8000/health
+
+ # Test completion
+ curl -f http://localhost:8000/v1/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Qwen/Qwen3-0.6B",
+ "prompt": "Hello",
+ "max_tokens": 5
+ }'
+
+ # Cleanup
+ kill "$SERVER_PID"
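
The same checks the workflow performs with curl can be reproduced locally against a running `vllm serve` instance. A minimal sketch using the third-party `requests` package (an assumption; the workflow itself only uses curl), with the same port, model, and payload as above:

import requests

BASE_URL = "http://localhost:8000"  # port used by the workflow above

# Health probe, equivalent to the curl polling loop in the workflow.
requests.get(f"{BASE_URL}/health", timeout=5).raise_for_status()

# Minimal completion request, mirroring the curl payload above.
resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": "Qwen/Qwen3-0.6B", "prompt": "Hello", "max_tokens": 5},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
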
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index e21d13b8161f..d5e70f30ef63 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -16,7 +16,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: "3.12"
diff --git a/.gitignore b/.gitignore
index 50070d7898fe..7cda86478664 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
+# OpenAI triton kernels copied from source
+vllm/third_party/triton_kernels/*
+
# triton jit
.triton
diff --git a/.markdownlint.yaml b/.markdownlint.yaml
index cd9df57cd980..937487f47364 100644
--- a/.markdownlint.yaml
+++ b/.markdownlint.yaml
@@ -3,10 +3,9 @@ MD007:
MD013: false
MD024:
siblings_only: true
+MD031:
+ list_items: false
MD033: false
-MD045: false
MD046: false
-MD051: false
MD052: false
-MD053: false
MD059: false
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dcc44be87e55..e09972fe7199 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -136,7 +136,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+ Torch_VERSION VERSION_LESS ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
@@ -307,7 +307,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
- set(CUTLASS_REVISION "v4.2.1" CACHE STRING "CUTLASS revision to use")
+ set(CUTLASS_REVISION "v4.2.1")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -354,8 +354,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
- # 9.0 for latest bf16 atomicAdd PTX
- cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
+
+ # marlin arches for fp16 output
+ cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
+ # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX)
+ cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
+ # marlin arches for fp8 input
+ # - sm80 doesn't support fp8 computation
+ # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
+ # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
+ cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
+
if (MARLIN_ARCHS)
#
@@ -365,16 +374,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
+ list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
+ set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
- message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
- message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
+ message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
+ message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
- if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
- OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
+ if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
+ OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
- ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
+ ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
@@ -387,15 +398,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else()
- set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
- CACHE STRING "Last run Marlin generate script hash" FORCE)
+ set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
+ CACHE STRING "Last run Marlin generate script hash and arch" FORCE)
message(STATUS "Marlin generation completed successfully.")
endif()
else()
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()
- file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
+ file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
@@ -403,12 +414,34 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
-
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
+ file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
+ set_gencode_flags_for_srcs(
+ SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
+ CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+ set_source_files_properties(${MARLIN_TEMPLATE_BF16_KERNEL_SRC}
+ PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
+ endif()
+ list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_BF16_KERNEL_SRC})
+
+ if (MARLIN_FP8_ARCHS)
+ file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
+ set_gencode_flags_for_srcs(
+ SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
+ CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+ set_source_files_properties(${MARLIN_TEMPLATE_FP8_KERNEL_SRC}
+ PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
+ endif()
+ list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_FP8_KERNEL_SRC})
+ endif()
+
set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
+ "csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs(
@@ -512,9 +545,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
# require CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
@@ -604,12 +637,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
- "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
+ "csrc/quantization/fp4/nvfp4_experts_quant.cu"
+ "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
+ "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
+ list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
@@ -619,9 +655,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
- cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
@@ -695,7 +731,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
@@ -741,9 +777,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
- cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
@@ -861,7 +897,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
# Hadacore kernels
- cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
if(HADACORE_ARCHS)
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
set_gencode_flags_for_srcs(
@@ -938,8 +974,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
CUDA_ARCHS "${CUDA_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
- # 9.0 for latest bf16 atomicAdd PTX
- cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
+ # moe marlin arches
+ # note that we always set `use_atomic_add=False` for moe marlin now,
+ # so we don't need 9.0 for bf16 atomicAdd PTX
+ cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}")
+ # moe marlin arches for fp8 input
+ # - sm80 doesn't support fp8 computation
+ # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
+ # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
+ cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
#
@@ -949,16 +992,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
+ list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
+ set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MOE_MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
- message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
- message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
+ message(STATUS "Marlin MOE generation script hash with arch: ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
+ message(STATUS "Last run Marlin MOE generate script hash with arch: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}")
- if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
- OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
+ if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
+ OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
- ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
+ ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
@@ -971,7 +1016,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
else()
- set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
+ set(MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}
CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
message(STATUS "Marlin MOE generation completed successfully.")
endif()
@@ -979,16 +1024,28 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
endif()
- file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
+ file(GLOB MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/sm80_kernel_*.cu")
+ list(APPEND MARLIN_MOE_SRC "csrc/moe/marlin_moe_wna16/ops.cu")
set_gencode_flags_for_srcs(
- SRCS "${MOE_WNAA16_MARLIN_SRC}"
+ SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
- set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
+ set_source_files_properties(${MARLIN_MOE_SRC}
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
endif()
-
- list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
+ list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_SRC})
+
+ if (MARLIN_MOE_FP8_ARCHS)
+ file(GLOB MARLIN_MOE_FP8_SRC "csrc/moe/marlin_moe_wna16/sm89_kernel_*.cu")
+ set_gencode_flags_for_srcs(
+ SRCS "${MARLIN_MOE_FP8_SRC}"
+ CUDA_ARCHS "${MARLIN_MOE_FP8_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
+ set_source_files_properties(${MARLIN_MOE_FP8_SRC}
+ PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
+ endif()
+ list(APPEND VLLM_MOE_EXT_SRC ${MARLIN_MOE_FP8_SRC})
+ endif()
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else()
@@ -1030,6 +1087,11 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
endif()
+# For CUDA and HIP builds also build the triton_kernels external package.
+if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
+ include(cmake/external_projects/triton_kernels.cmake)
+endif()
+
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake)
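
The regeneration guard changed above now keys the cached value on both the generator script's MD5 and the CUDA arch list, so editing the script or reconfiguring with a different architecture set both force the Marlin kernels to be regenerated. A rough Python sketch of that check (names are illustrative, not part of the actual build):

def marlin_needs_regeneration(script_md5: str, cuda_archs: list[str],
                              cached_key: str | None) -> bool:
    # Mirrors the "${HASH}(ARCH:${CUDA_ARCHS_STR})" key composed in the CMake above.
    key = f"{script_md5}(ARCH:{','.join(cuda_archs)})"
    return cached_key is None or cached_key != key

print(marlin_needs_regeneration("abc123", ["8.0+PTX", "8.9"], None))        # True
print(marlin_needs_regeneration("abc123", ["8.0+PTX", "8.9"],
                                "abc123(ARCH:8.0+PTX,8.9)"))                # False
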
diff --git a/README.md b/README.md
index 033e1035d891..abbb63158f16 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio
*Latest News* 🔥
+- [2025/11] We hosted [vLLM Bangkok Meetup](https://luma.com/v0f647nv). We explored vLLM and LMCache inference and low-resource language adaptation with speakers from Embedded LLM, AMD, and Red Hat. Please find the meetup slides [here](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing).
- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
- [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link).
- [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6).
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 4021fede7215..d69d74ca61f5 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -620,7 +620,7 @@ def get_tokenizer(
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
- from vllm.transformers_utils.tokenizer import MistralTokenizer
+ from vllm.tokenizers import MistralTokenizer
except ImportError as e:
raise ImportError(
"MistralTokenizer requires vllm package.\n"
diff --git a/benchmarks/benchmark_batch_invariance.py b/benchmarks/benchmark_batch_invariance.py
new file mode 100755
index 000000000000..b5c16c42de46
--- /dev/null
+++ b/benchmarks/benchmark_batch_invariance.py
@@ -0,0 +1,380 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode.
+
+This benchmark runs the same workload twice:
+1. With VLLM_BATCH_INVARIANT=0 (baseline)
+2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode)
+
+And reports the timing and throughput metrics for comparison.
+
+Environment variables:
+ VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B")
+ VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek)
+ VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128)
+ VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5)
+ VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024)
+ VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048)
+ VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128)
+ VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0)
+ VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4)
+ VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120)
+ VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN)
+
+Example usage:
+ # Benchmark qwen3 (default)
+ python benchmarks/benchmark_batch_invariance.py
+
+ # Benchmark deepseek with 8 GPUs
+ VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\
+ python benchmarks/benchmark_batch_invariance.py
+
+ # Quick test with fewer trials
+ VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\
+ python benchmarks/benchmark_batch_invariance.py
+"""
+
+import contextlib
+import os
+import random
+import time
+
+from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
+
+
+def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
+ """Generate a random prompt for benchmarking."""
+ prompt_templates = [
+ "Question: What is the capital of France?\nAnswer: The capital of France is",
+ "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+ "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+ "Once upon a time in a distant galaxy, there lived",
+ "The old man walked slowly down the street, remembering",
+ "In the year 2157, humanity finally discovered",
+ "To implement a binary search tree in Python, first we need to",
+ "The algorithm works by iterating through the array and",
+ "Here's how to optimize database queries using indexing:",
+ "The Renaissance was a period in European history that",
+ "Climate change is caused by several factors including",
+ "The human brain contains approximately 86 billion neurons which",
+ "I've been thinking about getting a new laptop because",
+ "Yesterday I went to the store and bought",
+ "My favorite thing about summer is definitely",
+ ]
+
+ base_prompt = random.choice(prompt_templates)
+
+ if max_words < min_words:
+ max_words = min_words
+ target_words = random.randint(min_words, max_words)
+
+ if target_words > 50:
+ padding_text = (
+ " This is an interesting topic that deserves more explanation. "
+ * (target_words // 50)
+ )
+ base_prompt = base_prompt + padding_text
+
+ return base_prompt
+
+
+def run_benchmark_with_batch_invariant(
+ model: str,
+ tp_size: int,
+ max_batch_size: int,
+ num_trials: int,
+ min_prompt: int,
+ max_prompt: int,
+ max_tokens: int,
+ temperature: float,
+ gpu_mem_util: float,
+ max_model_len: int,
+ backend: str,
+ batch_invariant: bool,
+ seed: int = 12345,
+) -> dict:
+ """
+ Run the benchmark with the specified configuration.
+
+ Returns a dict with timing and throughput metrics.
+ """
+ random.seed(seed)
+
+ # Set environment variables
+ os.environ["VLLM_ATTENTION_BACKEND"] = backend
+ if batch_invariant:
+ os.environ["VLLM_BATCH_INVARIANT"] = "1"
+ else:
+ os.environ["VLLM_BATCH_INVARIANT"] = "0"
+
+ print(f"\n{'=' * 80}")
+ print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}")
+ print(f" Model: {model}")
+ print(f" TP Size: {tp_size}")
+ print(f" Backend: {backend}")
+ print(f" Max Batch Size: {max_batch_size}")
+ print(f" Trials: {num_trials}")
+ print(f" Max Tokens: {max_tokens}")
+ print(f"{'=' * 80}\n")
+
+ sampling = SamplingParams(
+ temperature=temperature,
+ top_p=0.95,
+ max_tokens=max_tokens,
+ seed=20240919,
+ )
+
+ needle_prompt = "There once was a "
+
+ llm = None
+ try:
+ # Create LLM engine
+ start_init = time.perf_counter()
+ llm = LLM(
+ model=model,
+ max_num_seqs=max_batch_size,
+ gpu_memory_utilization=gpu_mem_util,
+ max_model_len=max_model_len,
+ dtype="bfloat16",
+ tensor_parallel_size=tp_size,
+ enable_prefix_caching=False,
+ )
+ init_time = time.perf_counter() - start_init
+ print(f"Engine initialization time: {init_time:.2f}s\n")
+
+ # Generate baseline
+ print("Generating baseline (warmup)...")
+ baseline_out = llm.generate([needle_prompt], sampling)
+ assert len(baseline_out) == 1
+ baseline_text = baseline_out[0].outputs[0].text
+ print(f"Baseline output: '{baseline_text[:50]}...'\n")
+
+ # Run trials and measure timing
+ trial_times: list[float] = []
+ total_tokens = 0
+ total_prompts = 0
+
+ for trial in range(num_trials):
+ # Create a batch
+ prompts: list[str] = []
+ batch_size = random.randint(max_batch_size // 2, max_batch_size)
+ needle_pos = random.randint(0, batch_size - 1)
+ for i in range(batch_size):
+ if i == needle_pos:
+ prompts.append(needle_prompt)
+ else:
+ prompts.append(_random_prompt(min_prompt, max_prompt))
+
+ # Measure time for this trial
+ start_time = time.perf_counter()
+ outputs = llm.generate(prompts, sampling)
+ trial_time = time.perf_counter() - start_time
+
+ trial_times.append(trial_time)
+ total_prompts += len(prompts)
+
+ # Count tokens
+ for output in outputs:
+ if output.outputs:
+ total_tokens += len(output.outputs[0].token_ids)
+
+ print(
+ f"Trial {trial + 1}/{num_trials}: "
+ f"batch_size={batch_size}, "
+ f"time={trial_time:.2f}s"
+ )
+
+ # Verify needle output still matches
+ needle_output = outputs[needle_pos]
+ assert needle_output.prompt == needle_prompt
+
+ # Compute statistics
+ avg_time = sum(trial_times) / len(trial_times)
+ min_time = min(trial_times)
+ max_time = max(trial_times)
+ throughput = total_tokens / sum(trial_times)
+ prompts_per_sec = total_prompts / sum(trial_times)
+
+ print(f"\n{'=' * 80}")
+ print("RESULTS:")
+ print(f" Average time per trial: {avg_time:.2f}s")
+ print(f" Min time: {min_time:.2f}s")
+ print(f" Max time: {max_time:.2f}s")
+ print(f" Total tokens generated: {total_tokens}")
+ print(f" Total prompts processed: {total_prompts}")
+ print(f" Throughput: {throughput:.2f} tokens/s")
+ print(f" Prompts/s: {prompts_per_sec:.2f}")
+ print(f"{'=' * 80}\n")
+
+ return {
+ "init_time": init_time,
+ "avg_time": avg_time,
+ "min_time": min_time,
+ "max_time": max_time,
+ "total_tokens": total_tokens,
+ "total_prompts": total_prompts,
+ "throughput": throughput,
+ "prompts_per_sec": prompts_per_sec,
+ "trial_times": trial_times,
+ }
+
+ finally:
+ # Cleanup
+ if llm is not None:
+ with contextlib.suppress(Exception):
+ llm.shutdown()
+
+
+def main():
+ # Check platform support
+ if not (current_platform.is_cuda() and current_platform.has_device_capability(90)):
+ print("ERROR: Requires CUDA and >= Hopper (SM90)")
+ print(f"Current platform: {current_platform.device_type}")
+ if current_platform.is_cuda():
+ print(f"Device capability: {current_platform.get_device_capability()}")
+ return 1
+
+ # Read configuration from environment
+ model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B")
+ tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1"))
+ max_batch_size = int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128"))
+ num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5"))
+ min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024"))
+ max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048"))
+ max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128"))
+ temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0"))
+ gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4"))
+ max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120"))
+ backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN")
+
+ print("\n" + "=" * 80)
+ print("VLLM BATCH INVARIANCE BENCHMARK")
+ print("=" * 80)
+ print("\nConfiguration:")
+ print(f" Model: {model}")
+ print(f" Tensor Parallel Size: {tp_size}")
+ print(f" Attention Backend: {backend}")
+ print(f" Max Batch Size: {max_batch_size}")
+ print(f" Number of Trials: {num_trials}")
+ print(f" Prompt Length Range: {min_prompt}-{max_prompt} words")
+ print(f" Max Tokens to Generate: {max_tokens}")
+ print(f" Temperature: {temperature}")
+ print(f" GPU Memory Utilization: {gpu_mem_util}")
+ print(f" Max Model Length: {max_model_len}")
+ print("=" * 80)
+
+ # Run benchmark WITHOUT batch invariance (baseline)
+ print("\n" + "=" * 80)
+ print("PHASE 1: Running WITHOUT batch invariance (baseline)")
+ print("=" * 80)
+ baseline_results = run_benchmark_with_batch_invariant(
+ model=model,
+ tp_size=tp_size,
+ max_batch_size=max_batch_size,
+ num_trials=num_trials,
+ min_prompt=min_prompt,
+ max_prompt=max_prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ gpu_mem_util=gpu_mem_util,
+ max_model_len=max_model_len,
+ backend=backend,
+ batch_invariant=False,
+ )
+
+ # Run benchmark WITH batch invariance
+ print("\n" + "=" * 80)
+ print("PHASE 2: Running WITH batch invariance")
+ print("=" * 80)
+ batch_inv_results = run_benchmark_with_batch_invariant(
+ model=model,
+ tp_size=tp_size,
+ max_batch_size=max_batch_size,
+ num_trials=num_trials,
+ min_prompt=min_prompt,
+ max_prompt=max_prompt,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ gpu_mem_util=gpu_mem_util,
+ max_model_len=max_model_len,
+ backend=backend,
+ batch_invariant=True,
+ )
+
+ # Compare results
+ print("\n" + "=" * 80)
+ print("COMPARISON: Batch Invariance vs Baseline")
+ print("=" * 80)
+
+ init_overhead_pct = (
+ (batch_inv_results["init_time"] - baseline_results["init_time"])
+ / baseline_results["init_time"]
+ * 100
+ )
+ time_overhead_pct = (
+ (batch_inv_results["avg_time"] - baseline_results["avg_time"])
+ / baseline_results["avg_time"]
+ * 100
+ )
+ throughput_change_pct = (
+ (batch_inv_results["throughput"] - baseline_results["throughput"])
+ / baseline_results["throughput"]
+ * 100
+ )
+
+ print("\nInitialization Time:")
+ print(f" Baseline: {baseline_results['init_time']:.2f}s")
+ print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s")
+ print(f" Overhead: {init_overhead_pct:+.2f}%")
+
+ print("\nAverage Trial Time:")
+ print(f" Baseline: {baseline_results['avg_time']:.2f}s")
+ print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s")
+ print(f" Overhead: {time_overhead_pct:+.2f}%")
+
+ print("\nThroughput (tokens/s):")
+ print(f" Baseline: {baseline_results['throughput']:.2f}")
+ print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}")
+ print(f" Change: {throughput_change_pct:+.2f}%")
+
+ print("\nPrompts/s:")
+ print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}")
+ print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}")
+
+ print("\n" + "=" * 80)
+ print("SUMMARY")
+ print("=" * 80)
+ if time_overhead_pct > 0:
+ print(
+ f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% "
+ "overhead"
+ )
+ else:
+ print(
+ f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% "
+ "faster (unexpected!)"
+ )
+
+ if abs(throughput_change_pct) < 1.0:
+ print("Throughput difference is negligible (< 1%)")
+ elif throughput_change_pct < 0:
+ print(
+ f"Throughput decreased by {-throughput_change_pct:.1f}% "
+ "with batch invariance"
+ )
+ else:
+ print(
+ f"Throughput increased by {throughput_change_pct:.1f}% "
+ "with batch invariance (unexpected!)"
+ )
+
+ print("=" * 80 + "\n")
+
+ return 0
+
+
+if __name__ == "__main__":
+ exit(main())
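
For reference, the comparison printed at the end of the benchmark reduces to relative deltas between the two phases; a worked example of that arithmetic with made-up numbers:

# Illustrative numbers only; the benchmark fills these in from its two runs.
baseline_avg_s, batch_invariant_avg_s = 10.0, 11.5
overhead_pct = (batch_invariant_avg_s - baseline_avg_s) / baseline_avg_s * 100
print(f"Overhead: {overhead_pct:+.2f}%")  # -> Overhead: +15.00%
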
diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
index 146c268a6b7f..28fc383a318d 100644
--- a/benchmarks/benchmark_prefix_caching.py
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -69,7 +69,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
# Remove the special tokens.
return random.choices(
- [v for k, v in vocab.items() if k not in all_special_ids],
+ [v for v in vocab.values() if v not in all_special_ids],
k=length,
)
diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
index 904f80534914..d072c03c440b 100644
--- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
+++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
@@ -5,11 +5,12 @@
import asyncio
import logging
import os
+import time
+import uuid
+from urllib.parse import urlparse
import aiohttp
from quart import Quart, Response, make_response, request
-from rate_limiter import RateLimiter
-from request_queue import RequestQueue
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -24,26 +25,8 @@ def parse_args():
parser.add_argument(
"--timeout",
type=float,
- default=300,
- help="Timeout for backend service requests in seconds (default: 300)",
- )
- parser.add_argument(
- "--max-concurrent",
- type=int,
- default=100,
- help="Maximum concurrent requests to backend services (default: 100)",
- )
- parser.add_argument(
- "--queue-size",
- type=int,
- default=500,
- help="Maximum number of requests in the queue (default: 500)",
- )
- parser.add_argument(
- "--rate-limit",
- type=int,
- default=40,
- help="Maximum requests per second (default: 40)",
+ default=6 * 60 * 60,
+ help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--port",
@@ -54,14 +37,32 @@ def parse_args():
parser.add_argument(
"--prefill-url",
type=str,
- default="http://localhost:8100/v1/completions",
- help="Prefill service endpoint URL",
+ default="http://localhost:8100",
+ help="Prefill service base URL (protocol + host[:port])",
)
parser.add_argument(
"--decode-url",
type=str,
- default="http://localhost:8200/v1/completions",
- help="Decode service endpoint URL",
+ default="http://localhost:8200",
+ help="Decode service base URL (protocol + host[:port])",
+ )
+ parser.add_argument(
+ "--kv-host",
+ type=str,
+ default="localhost",
+ help="Hostname or IP used by KV transfer (default: localhost)",
+ )
+ parser.add_argument(
+ "--prefill-kv-port",
+ type=int,
+ default=14579,
+ help="Prefill KV port (default: 14579)",
+ )
+ parser.add_argument(
+ "--decode-kv-port",
+ type=int,
+ default=14580,
+ help="Decode KV port (default: 14580)",
)
return parser.parse_args()
@@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
- MAX_CONCURRENT_REQUESTS = args.max_concurrent
- REQUEST_QUEUE_SIZE = args.queue_size
- RATE_LIMIT = args.rate_limit
PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url
PORT = args.port
- app = Quart(__name__)
+ PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
+ DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
- # Initialize the rate limiter and request queue
- rate_limiter = RateLimiter(RATE_LIMIT)
- request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
+ logger.info(
+ "Proxy resolved KV addresses -> prefill: %s, decode: %s",
+ PREFILL_KV_ADDR,
+ DECODE_KV_ADDR,
+ )
+
+ app = Quart(__name__)
- # Attach the configuration object to the application instance
+ # Attach the configuration object to the application instance so helper
+ # coroutines can read the resolved backend URLs and timeouts without using
+ # globals.
app.config.update(
{
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
- "rate_limiter": rate_limiter,
- "request_queue": request_queue,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
+ "PREFILL_KV_ADDR": PREFILL_KV_ADDR,
+ "DECODE_KV_ADDR": DECODE_KV_ADDR,
}
)
- # Start queue processing on app startup
- @app.before_serving
- async def startup():
- """Start request processing task when app starts serving"""
- asyncio.create_task(request_queue.process())
-
- async def forward_request(url, data):
- """Forward request to backend service with rate limiting and error handling"""
- headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
-
- # Use rate limiter as context manager
- async with (
- rate_limiter,
- aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
- ):
- try:
- async with session.post(
- url=url, json=data, headers=headers
- ) as response:
- if response.status == 200:
- # Stream response chunks
- async for chunk_bytes in response.content.iter_chunked(1024):
- yield chunk_bytes
- else:
- # Handle backend service errors
- error_text = await response.text()
- logger.error(
- "Backend service error: %s - %s",
- response.status,
- error_text,
- )
- yield b'{"error": "Backend service error"}'
- except aiohttp.ClientError as e:
- # Handle connection errors
- logger.error("Connection error to %s: %s", url, str(e))
- yield b'{"error": "Service unavailable"}'
- except asyncio.TimeoutError:
- # Handle timeout errors
- logger.error("Timeout connecting to %s", url)
- yield b'{"error": "Service timeout"}'
+ def _normalize_base_url(url: str) -> str:
+ """Remove any trailing slash so path joins behave predictably."""
+ return url.rstrip("/")
+
+ def _get_host_port(url: str) -> str:
+ """Return the hostname:port portion for logging and KV headers."""
+ parsed = urlparse(url)
+ host = parsed.hostname or "localhost"
+ port = parsed.port
+ if port is None:
+ port = 80 if parsed.scheme == "http" else 443
+ return f"{host}:{port}"
+
+ PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
+ DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
+ KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
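+    # e.g. with the default --decode-url, KV_TARGET == "localhost:8200"
+    # (80 or 443 is assumed when the URL has no explicit port).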
+
+ def _build_headers(request_id: str) -> dict[str, str]:
+ """Construct the headers expected by vLLM's P2P disagg connector."""
+ headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+ return headers
+
+ async def _run_prefill(
+ request_path: str,
+ payload: dict,
+ headers: dict[str, str],
+ request_id: str,
+ ):
+ url = f"{PREFILL_BASE}{request_path}"
+ start_ts = time.perf_counter()
+ logger.info("[prefill] start request_id=%s url=%s", request_id, url)
+ try:
+ async with (
+ aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
+ session.post(url=url, json=payload, headers=headers) as resp,
+ ):
+ if resp.status != 200:
+ error_text = await resp.text()
+ raise RuntimeError(
+ f"Prefill backend error {resp.status}: {error_text}"
+ )
+ await resp.read()
+ logger.info(
+ "[prefill] done request_id=%s status=%s elapsed=%.2fs",
+ request_id,
+ resp.status,
+ time.perf_counter() - start_ts,
+ )
+ except asyncio.TimeoutError as exc:
+ raise RuntimeError(f"Prefill service timeout at {url}") from exc
+ except aiohttp.ClientError as exc:
+ raise RuntimeError(f"Prefill service unavailable at {url}") from exc
+
+ async def _stream_decode(
+ request_path: str,
+ payload: dict,
+ headers: dict[str, str],
+ request_id: str,
+ ):
+ url = f"{DECODE_BASE}{request_path}"
+ # Stream tokens from the decode service once the prefill stage has
+ # materialized KV caches on the target workers.
+ logger.info("[decode] start request_id=%s url=%s", request_id, url)
+ try:
+ async with (
+ aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
+ session.post(url=url, json=payload, headers=headers) as resp,
+ ):
+ if resp.status != 200:
+ error_text = await resp.text()
+ logger.error(
+ "Decode backend error %s - %s", resp.status, error_text
+ )
+ err_msg = (
+ '{"error": "Decode backend error ' + str(resp.status) + '"}'
+ )
+ yield err_msg.encode()
+ return
+ logger.info(
+ "[decode] streaming response request_id=%s status=%s",
+ request_id,
+ resp.status,
+ )
+ async for chunk_bytes in resp.content.iter_chunked(1024):
+ yield chunk_bytes
+ logger.info("[decode] finished streaming request_id=%s", request_id)
+ except asyncio.TimeoutError:
+ logger.error("Decode service timeout at %s", url)
+ yield b'{"error": "Decode service timeout"}'
+ except aiohttp.ClientError as exc:
+ logger.error("Decode service error at %s: %s", url, exc)
+ yield b'{"error": "Decode service unavailable"}'
async def process_request():
"""Process a single request through prefill and decode stages"""
@@ -146,13 +206,27 @@ async def process_request():
# Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1
+ if "max_completion_tokens" in prefill_request:
+ prefill_request["max_completion_tokens"] = 1
# Execute prefill stage
- async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
- continue
+ # The request id encodes both KV socket addresses so the backend can
+ # shuttle tensors directly via NCCL once the prefill response
+ # completes.
+ request_id = (
+ f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
+ f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
+ )
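+            # e.g. with the default KV settings this looks like
+            # "___prefill_addr_localhost:14579___decode_addr_localhost:14580_<uuid4-hex>"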
+
+ headers = _build_headers(request_id)
+ await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response
- generator = forward_request(DECODE_SERVICE_URL, original_request_data)
+ # Pass the unmodified user request so the decode phase can continue
+ # sampling with the already-populated KV cache.
+ generator = _stream_decode(
+ request.path, original_request_data, headers, request_id
+ )
response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response
return response
@@ -168,23 +242,10 @@ async def process_request():
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting"""
- # Create task for request processing
- task = asyncio.create_task(process_request())
-
- # Enqueue request or reject if queue is full
- if not await request_queue.enqueue(task):
- return Response(
- response=b'{"error": "Server busy, try again later"}',
- status=503,
- content_type="application/json",
- )
-
try:
- # Return the response from the processing task
- return await task
+ return await process_request()
except asyncio.CancelledError:
- # Handle task cancellation (timeout or queue full)
- logger.warning("Request cancelled due to timeout or queue full")
+ logger.warning("Request cancelled")
return Response(
response=b'{"error": "Request cancelled"}',
status=503,
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
index 027f67ad4db6..e07d6c776bc0 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
@@ -255,8 +255,8 @@ def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
torch.cuda.synchronize()
# Timing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies = []
for _ in range(num_iters):
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py
index 8787724d77cf..ac78c019a59e 100644
--- a/benchmarks/kernels/benchmark_machete.py
+++ b/benchmarks/kernels/benchmark_machete.py
@@ -237,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight=w_q,
b_bias=None,
b_scales=w_s,
+ a_scales=None,
global_scale=None,
b_zeros=w_zp,
g_idx=g_idx,
diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py
index 12ca9214b1f9..48d790aec9e0 100644
--- a/benchmarks/kernels/benchmark_marlin.py
+++ b/benchmarks/kernels/benchmark_marlin.py
@@ -263,7 +263,7 @@ def gen_allspark_params():
results.append(
benchmark.Timer(
- stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
+ stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -273,7 +273,7 @@ def gen_allspark_params():
results.append(
benchmark.Timer(
- stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
+ stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index c99951aa2782..a1af0b8aec3d 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -185,8 +185,8 @@ def run():
graph.replay()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index efa5a7386027..b8913a217c60 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -105,8 +105,8 @@ def run():
graph.replay()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
@@ -241,8 +241,8 @@ def run(input: tuple):
graph.replay()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py
index cb848d2bf579..83bd91917508 100644
--- a/benchmarks/kernels/benchmark_mrope.py
+++ b/benchmarks/kernels/benchmark_mrope.py
@@ -6,7 +6,7 @@
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
-# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
+# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads: int,
num_kv_heads: int,
max_position: int = 8192,
- rope_theta: float = 10000,
is_neox_style: bool = True,
- rope_scaling: dict[str, Any] = None,
+ rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype = torch.bfloat16,
seed: int = 0,
warmup_iter: int = 10,
@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
- base=rope_theta,
is_neox_style=is_neox_style,
- rope_scaling=rope_scaling,
+ rope_parameters=rope_parameters,
dtype=dtype,
).to(device=device)
@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads,
head_dim,
max_position,
- rope_theta,
is_neox_style,
- str(rope_scaling),
+ str(rope_parameters),
str(dtype).split(".")[-1],
torch_stats["mean"],
torch_stats["median"],
@@ -255,9 +252,8 @@ def benchmark_mrope(
"num_kv_heads",
"head_dim",
"max_position",
- "rope_theta",
"is_neox_style",
- "rope_scaling",
+ "rope_parameters",
"dtype",
"torch_mean",
"torch_median",
@@ -303,7 +299,7 @@ def benchmark_mrope(
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
is_neox_style = True
- rope_theta = config.rope_theta
+ rope_parameters = config.rope_parameters
max_position = config.max_position_embeddings
for num_tokens in num_tokens_list:
@@ -315,9 +311,8 @@ def benchmark_mrope(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
- rope_theta=rope_theta,
is_neox_style=is_neox_style,
- rope_scaling=config.rope_scaling,
+ rope_parameters=rope_parameters,
dtype=getattr(torch, args.dtype),
seed=args.seed,
warmup_iter=args.warmup_iter,
diff --git a/benchmarks/kernels/benchmark_per_token_group_quant.py b/benchmarks/kernels/benchmark_per_token_group_quant.py
index bdc1eb733084..eba4d510258b 100644
--- a/benchmarks/kernels/benchmark_per_token_group_quant.py
+++ b/benchmarks/kernels/benchmark_per_token_group_quant.py
@@ -30,8 +30,8 @@ def _time_cuda(
fn()
torch.cuda.synchronize()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
start.record()
for _ in range(bench_iters):
diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
index a5887aafd30d..de01ff197eab 100644
--- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
+++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
@@ -253,8 +253,8 @@ def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
)
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
# Benchmark
latencies: list[float] = []
diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index 29ce18234dfa..1d0d6fbb9a47 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -127,8 +127,8 @@ def benchmark_decode(
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
index 2a25d0374811..84bde723abf7 100644
--- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
@@ -139,8 +139,8 @@ def benchmark_prefill(
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
+ start = torch.Event(enable_timing=True)
+ end = torch.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
index ab54f81985bc..b52500c8c521 100644
--- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py
+++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
@@ -183,8 +183,8 @@ def run():
run()
torch.cuda.synchronize()
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
+ start_event = torch.Event(enable_timing=True)
+ end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md
index 41e68e047be8..a28c6956be0e 100644
--- a/benchmarks/kernels/deepgemm/README.md
+++ b/benchmarks/kernels/deepgemm/README.md
@@ -2,7 +2,7 @@
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
-Currently this just includes dense GEMMs and only works on Hopper GPUs.
+Currently, this just includes dense GEMMs and only works on Hopper GPUs.
## Setup
diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md
index f5b5c6c97d48..b0be1e3a69a6 100644
--- a/benchmarks/multi_turn/README.md
+++ b/benchmarks/multi_turn/README.md
@@ -55,6 +55,10 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75
----------------------------------------------------------------------------------------------------
```
+If you run with `--warmup-step`, the summary will also include `warmup_runtime_sec`
+and `total_runtime_incl_warmup_sec` (while `runtime_sec` continues to reflect the
+benchmark-only runtime so the reported throughput stays comparable).
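+
+For example, a 12.3 sec warmup followed by a 120.4 sec benchmark would report
+`warmup_runtime_sec` = 12.3, `runtime_sec` = 120.4, and
+`total_runtime_incl_warmup_sec` = 132.7 (illustrative numbers).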
+
### JSON configuration file for synthetic conversations generation
The input flag `--input-file` is used to determine the input conversations for the benchmark.
diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
index ae9e9753441a..e23f6b923f1b 100644
--- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py
+++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
@@ -561,8 +561,11 @@ async def client_main(
f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501
)
- random.seed(args.seed)
- np.random.seed(args.seed)
+ # Set unique seed per client (each client runs in its own process)
+ # Add 1 to ensure no client uses the same seed as the main process
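+    # (e.g. with --seed 0: the main process seeds with 0, client 0 with 1,
+    #  client 1 with 2, and so on)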
+ client_seed = args.seed + client_id + 1
+ random.seed(client_seed)
+ np.random.seed(client_seed)
# Active conversations
active_convs: ConversationsMap = {}
@@ -1073,6 +1076,7 @@ def process_statistics(
verbose: bool,
gen_conv_args: GenConvArgs | None = None,
excel_output: bool = False,
+ warmup_runtime_sec: float | None = None,
) -> None:
if len(client_metrics) == 0:
logger.info("No samples to process")
@@ -1166,8 +1170,13 @@ def process_statistics(
# Convert milliseconds to seconds
runtime_sec = runtime_sec / 1000.0
requests_per_sec = float(len(df)) / runtime_sec
-
- params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec}
+ params = {
+ "runtime_sec": runtime_sec,
+ "requests_per_sec": requests_per_sec,
+ }
+ if warmup_runtime_sec is not None:
+ params["warmup_runtime_sec"] = warmup_runtime_sec
+ params["total_runtime_incl_warmup_sec"] = runtime_sec + warmup_runtime_sec
# Generate a summary of relevant metrics (and drop irrelevant data)
df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose()
@@ -1490,6 +1499,7 @@ async def main() -> None:
f"Invalid --warmup-percentage={args.warmup_percentage}"
) from None
+ # Set global seeds for main process
random.seed(args.seed)
np.random.seed(args.seed)
@@ -1548,6 +1558,8 @@ async def main() -> None:
url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
)
+ warmup_runtime_sec: float | None = None
+
# Warm-up step
if args.warmup_step:
# Only send a single user prompt from every conversation.
@@ -1562,26 +1574,56 @@ async def main() -> None:
# all clients should finish their work before exiting
warmup_bench_args = bench_args._replace(early_stop=False)
- logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}")
+ logger.info("%sWarmup start%s", Color.PURPLE, Color.RESET)
+ warmup_start_ns = time.perf_counter_ns()
conversations, _ = await main_mp(
warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations
)
- logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}")
+ warmup_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - warmup_start_ns)
+ logger.info(
+ "%sWarmup runtime: %.3f sec (%.3f ms)%s",
+ Color.PURPLE,
+ warmup_runtime_sec,
+ warmup_runtime_sec * 1000,
+ Color.RESET,
+ )
+ logger.info("%sWarmup done%s", Color.PURPLE, Color.RESET)
# Run the benchmark
- start_time = time.perf_counter_ns()
+ benchmark_start_ns = time.perf_counter_ns()
client_convs, client_metrics = await main_mp(
client_args, req_args, bench_args, tokenizer, conversations
)
- total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time)
+ benchmark_runtime_sec = nanosec_to_sec(time.perf_counter_ns() - benchmark_start_ns)
# Calculate requests per second
- total_runtime_sec = total_runtime_ms / 1000.0
- rps = len(client_metrics) / total_runtime_sec
+ requests_per_sec = len(client_metrics) / benchmark_runtime_sec
+ benchmark_runtime_ms = benchmark_runtime_sec * 1000.0
logger.info(
- f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec"
- f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}"
+ "%sAll clients finished, benchmark runtime: %.3f sec (%.3f ms), "
+ "requests per second: %.3f%s",
+ Color.GREEN,
+ benchmark_runtime_sec,
+ benchmark_runtime_ms,
+ requests_per_sec,
+ Color.RESET,
)
+ if warmup_runtime_sec is not None:
+ total_runtime_sec = benchmark_runtime_sec + warmup_runtime_sec
+ logger.info(
+ "%sWarmup runtime: %.3f sec (%.3f ms)%s",
+ Color.GREEN,
+ warmup_runtime_sec,
+ warmup_runtime_sec * 1000,
+ Color.RESET,
+ )
+ logger.info(
+ "%sTotal runtime (including warmup): %.3f sec (%.3f ms)%s",
+ Color.GREEN,
+ total_runtime_sec,
+ total_runtime_sec * 1000,
+ Color.RESET,
+ )
# Benchmark parameters
params = {
@@ -1606,6 +1648,7 @@ async def main() -> None:
verbose=args.verbose,
gen_conv_args=gen_conv_args,
excel_output=args.excel_output,
+ warmup_runtime_sec=warmup_runtime_sec,
)
if args.output_file is not None:
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index bb0179c79c10..fbbb03c5ed46 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -242,7 +242,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild"
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src"
GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git
- GIT_TAG v52.2.0
+ GIT_TAG v52.6.0
GIT_SHALLOW TRUE
GIT_PROGRESS TRUE
)
@@ -310,7 +310,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.9
+ GIT_TAG v3.10
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
@@ -375,6 +375,7 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
+ "csrc/cpu/cpu_wna16.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC
diff --git a/cmake/external_projects/triton_kernels.cmake b/cmake/external_projects/triton_kernels.cmake
new file mode 100644
index 000000000000..d35ad123dd9d
--- /dev/null
+++ b/cmake/external_projects/triton_kernels.cmake
@@ -0,0 +1,53 @@
+# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
+
+set(DEFAULT_TRITON_KERNELS_TAG "v3.5.0")
+
+# Set TRITON_KERNELS_SRC_DIR for local development with vLLM. We expect
+# TRITON_KERNELS_SRC_DIR to point directly at the triton_kernels python directory.
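+# Example for local development (path is illustrative):
+#   export TRITON_KERNELS_SRC_DIR=/path/to/triton/python/triton_kernels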
+if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
+ message(STATUS "[triton_kernels] Fetch from $ENV{TRITON_KERNELS_SRC_DIR}")
+ FetchContent_Declare(
+ triton_kernels
+ SOURCE_DIR $ENV{TRITON_KERNELS_SRC_DIR}
+ )
+
+else()
+ set(TRITON_GIT "https://github.com/triton-lang/triton.git")
+ message (STATUS "[triton_kernels] Fetch from ${TRITON_GIT}:${DEFAULT_TRITON_KERNELS_TAG}")
+ FetchContent_Declare(
+ triton_kernels
+ # TODO (varun) : Fetch just the triton_kernels directory from Triton
+ GIT_REPOSITORY https://github.com/triton-lang/triton.git
+ GIT_TAG ${DEFAULT_TRITON_KERNELS_TAG}
+ GIT_PROGRESS TRUE
+ SOURCE_SUBDIR python/triton_kernels/triton_kernels
+ )
+endif()
+
+# Fetch content
+FetchContent_MakeAvailable(triton_kernels)
+
+if (NOT triton_kernels_SOURCE_DIR)
+ message (FATAL_ERROR "[triton_kernels] Cannot resolve triton_kernels_SOURCE_DIR")
+endif()
+
+if (DEFINED ENV{TRITON_KERNELS_SRC_DIR})
+ set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/")
+else()
+ set(TRITON_KERNELS_PYTHON_DIR "${triton_kernels_SOURCE_DIR}/python/triton_kernels/triton_kernels/")
+endif()
+
+message (STATUS "[triton_kernels] triton_kernels is available at ${TRITON_KERNELS_PYTHON_DIR}")
+
+add_custom_target(triton_kernels)
+
+# Ensure the vllm/third_party directory exists before installation
+install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/third_party/triton_kernels\")")
+
+## Copy .py files to install directory.
+install(DIRECTORY
+ ${TRITON_KERNELS_PYTHON_DIR}
+ DESTINATION
+ vllm/third_party/triton_kernels/
+ COMPONENT triton_kernels
+ FILES_MATCHING PATTERN "*.py")
diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index 29db9fa273a4..ff687e0af7b4 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
- GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309
+ GIT_TAG 86f8f157cf82aa2342743752b97788922dd7de43
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index ca0062ba4fab..5047c354ff7d 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -495,7 +495,13 @@ function (define_extension_target MOD_NAME)
set(SOABI_KEYWORD "")
endif()
- if (ARG_USE_SABI)
+ run_python(IS_FREETHREADED_PYTHON
+ "import sysconfig; print(1 if sysconfig.get_config_var(\"Py_GIL_DISABLED\") else 0)"
+ "Failed to determine whether interpreter is free-threaded")
+
+ # Free-threaded Python doesn't yet support the stable ABI (see PEP 803/809),
+ # so avoid using the stable ABI under free-threading only.
+ if (ARG_USE_SABI AND NOT IS_FREETHREADED_PYTHON)
Python_add_library(${MOD_NAME} MODULE USE_SABI ${ARG_USE_SABI} ${SOABI_KEYWORD} "${ARG_SOURCES}")
else()
Python_add_library(${MOD_NAME} MODULE ${SOABI_KEYWORD} "${ARG_SOURCES}")
diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
index 229d9862fb67..27d1e990c611 100644
--- a/csrc/attention/merge_attn_states.cu
+++ b/csrc/attention/merge_attn_states.cu
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
- const uint head_size) {
+ const uint head_size, const uint prefix_head_stride,
+ const uint output_head_stride) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
- const uint head_offset =
- token_idx * num_heads * head_size + head_idx * head_size;
- const scalar_t* prefix_head_ptr = prefix_output + head_offset;
- const scalar_t* suffix_head_ptr = suffix_output + head_offset;
- scalar_t* output_head_ptr = output + head_offset;
+ const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
+ head_idx * prefix_head_stride;
+ const uint dst_head_offset = token_idx * num_heads * output_head_stride +
+ head_idx * output_head_stride;
+ const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
+ const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
+ scalar_t* output_head_ptr = output + dst_head_offset;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast(prefix_lse.data_ptr()), \
reinterpret_cast(suffix_output.data_ptr()), \
reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \
- num_heads, head_size); \
+ num_heads, head_size, prefix_head_stride, output_head_stride); \
}
/*@brief Merges the attention states from prefix and suffix
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
+ const uint prefix_head_stride = prefix_output.stride(1);
+ const uint output_head_stride = output.stride(1);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
- TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
- "output heads must be contiguous in memory");
- TORCH_CHECK(
- prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
- "prefix_output heads must be contiguous in memory");
- TORCH_CHECK(
- suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
- "suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr();
diff --git a/csrc/cache.h b/csrc/cache.h
index b162a4a2bc31..f2a5ec0acf5c 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -41,11 +41,12 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0aa0dc14c748..8a5457206c70 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -552,7 +552,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
#ifndef USE_ROCM
__syncwarp();
#endif
+#if defined(__gfx942__)
+ float scale = fmaxf(amax, 1e-4) / 224.0f;
+#else
float scale = fmaxf(amax, 1e-4) / 448.0f;
+#endif
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
@@ -901,87 +905,80 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
-template
+template
__global__ void gather_and_maybe_dequant_cache(
- const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
- // ENTRIES...]
- scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
- const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
- const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
- const int32_t block_size, const int32_t entry_size,
+ const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
+ // ENTRIES...]
+ scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
+ const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
+ const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
+ const int32_t* __restrict__ token_to_seq, // [MAX_TOKEN_ACROSS_CHUNK]
+ const int32_t num_tokens, const int32_t block_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
+ constexpr int vec_size = sizeof(float4) / sizeof(scalar_t);
+  using ltype = vllm::vec_n_t<cache_t, vec_size>;
+  using stype = vllm::vec_n_t<scalar_t, vec_size>;
+  // We are adding this for code readability; it is compiled out in release
+  // builds.
+ assert(CTA_SIZE == blockDim.x);
- const int64_t bid = blockIdx.x; // Batch ID
- const int32_t num_splits = gridDim.y;
- const int32_t split = blockIdx.y;
- const int32_t seq_start = cu_seq_lens[bid];
- const int32_t seq_end = cu_seq_lens[bid + 1];
- const int32_t seq_len = seq_end - seq_start;
- const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
- const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
-
- const int32_t split_start = split * split_blocks;
- const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
-
- const bool is_active_split = (split_start < tot_blocks);
- const bool is_last_split = (split_end == tot_blocks);
-
- if (!is_active_split) return;
-
- int32_t full_blocks_end = split_end;
- int32_t partial_block_size = 0;
-
- // Adjust the pointer for the block_table for this batch.
- // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
- // page_size)
- const int32_t batch_offset = bid * block_table_stride;
- int32_t offset = 0;
- if (seq_starts != nullptr) {
- offset = seq_starts[bid] / block_size;
- }
- const int32_t* batch_block_table = block_table + batch_offset + offset;
-
- // Adjust dst pointer based on the cumulative sequence lengths.
- dst += seq_start * dst_entry_stride;
-
- if (is_last_split) {
- partial_block_size = seq_len % block_size;
- if (partial_block_size) full_blocks_end -= 1;
- }
+#pragma unroll
+ for (int token_id = blockIdx.x; token_id < num_tokens;
+ token_id += gridDim.x) {
+ int64_t batch_id = token_to_seq[token_id];
+ int64_t batch_start = cu_seq_lens[batch_id];
+ int64_t batch_end = cu_seq_lens[batch_id + 1];
+ int32_t batch_offset = token_id - batch_start;
+
+ if (token_id >= batch_end) return;
+ int32_t offset = 0;
+ if (seq_starts != nullptr) {
+ offset = seq_starts[batch_id];
+ }
+ batch_offset += offset;
+ int32_t block_table_id = batch_offset / block_size;
+ int32_t slot_id = batch_offset % block_size;
+ int32_t block_table_offset = batch_id * block_table_stride + block_table_id;
+ int32_t block_id = block_table[block_table_offset];
+ int64_t cache_offset =
+ block_id * cache_block_stride + slot_id * cache_entry_stride;
+ constexpr int32_t vec_iter_cnt = ENTRY_SIZE / vec_size;
+ scalar_t* dst_ = dst + token_id * dst_entry_stride;
+    cache_t* src_ = const_cast<cache_t*>(src_cache) + cache_offset;
- auto copy_entry = [&](const cache_t* __restrict__ _src,
- scalar_t* __restrict__ _dst) {
- for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+#pragma unroll
+ for (int idx = threadIdx.x; idx < vec_iter_cnt; idx += CTA_SIZE) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      _dst[i] = static_cast<scalar_t>(_src[i]);
+        reinterpret_cast<stype*>(dst_)[idx] =
+            static_cast<stype>(reinterpret_cast<ltype*>(src_)[idx]);
} else {
-      _dst[i] =
-          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+        ltype loaded_val = reinterpret_cast<ltype*>(src_)[idx];
+        stype store_val;
+#pragma unroll
+        for (int j = 0; j < vec_size; ++j) {
+          store_val.val[j] = fp8::scaled_convert<scalar_t, cache_t, kv_dt>(
+              loaded_val.val[j], *scale);
+        }
+        reinterpret_cast<stype*>(dst_)[idx] = store_val;
}
}
- };
-
- for (int pid = split_start; pid < full_blocks_end; ++pid) {
- auto block_id = batch_block_table[pid];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
- for (int eid = 0; eid < block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
- }
- }
-
- if (partial_block_size) {
- auto block_id = batch_block_table[full_blocks_end];
- auto block_start_ptr = src_cache + block_id * cache_block_stride;
- auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
- for (int eid = 0; eid < partial_block_size; ++eid) {
- copy_entry(block_start_ptr + eid * cache_entry_stride,
- block_dst_ptr + eid * dst_entry_stride);
+ // process tail
+ constexpr int32_t tail_cnt = ENTRY_SIZE % vec_size;
+ dst_ = dst_ + ENTRY_SIZE - tail_cnt;
+ src_ = src_ + ENTRY_SIZE - tail_cnt;
+#pragma unroll
+ for (int idx = threadIdx.x; idx < tail_cnt; idx += CTA_SIZE) {
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst_[idx] = static_cast<scalar_t>(src_[idx]);
+ } else {
+      dst_[idx] =
+          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(src_[idx], *scale);
+ }
}
}
}
@@ -992,34 +989,38 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
-#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
- vllm::gather_and_maybe_dequant_cache \
- <<>>( \
- reinterpret_cast(src_cache.data_ptr()), \
- reinterpret_cast(dst.data_ptr()), \
- block_table.data_ptr(), cu_seq_lens.data_ptr(), \
- block_size, entry_size, block_table_stride, cache_block_stride, \
- cache_entry_stride, dst_entry_stride, \
- reinterpret_cast(scale.data_ptr()), seq_starts_ptr);
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+ vllm::gather_and_maybe_dequant_cache \
+ <<>>( \
+ reinterpret_cast(src_cache.data_ptr()), \
+ reinterpret_cast(dst.data_ptr()), \
+ block_table.data_ptr(), cu_seq_lens.data_ptr(), \
+ token_to_seq.data_ptr(), num_tokens, block_size, \
+ block_table_stride, cache_block_stride, cache_entry_stride, \
+ dst_entry_stride, reinterpret_cast(scale.data_ptr()), \
+ seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
+// - token_to_seq contains the back mapping from token_id to batch_id
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_and_maybe_dequant_cache(
- torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
- torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
- torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
+ torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
+ torch::Tensor const& cu_seq_lens, // [BATCH+1]
+ torch::Tensor const& token_to_seq, // [MAX_TOKEN_ACROSS_CHUNKS]
+ int64_t num_tokens, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
- int32_t entry_size = src_cache.flatten(2, -1).size(2);
+ int32_t head_dim = dst.size(-1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
@@ -1029,6 +1030,9 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
+  TORCH_CHECK(head_dim == 576,
+              "gather_and_maybe_dequant_cache only supports head_dim == 576 "
+              "for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
@@ -1046,10 +1050,9 @@ void gather_and_maybe_dequant_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
- // Decide on the number of splits based on the batch size.
- int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
- dim3 grid(batch_size, num_splits);
- dim3 block(1024);
+ constexpr int32_t thread_block_size = 64;
+ dim3 grid(num_tokens);
+ dim3 block(thread_block_size);
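+  // i.e. launch one 64-thread block per gathered token; threads in the block
+  // cooperate on the vectorized copy of that token's cache entry.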
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr;
diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp
index 50f17c758c14..92f8bee5a47a 100644
--- a/csrc/cpu/cpu_attn.cpp
+++ b/csrc/cpu/cpu_attn.cpp
@@ -13,6 +13,18 @@
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
+#ifdef __aarch64__
+ #include "cpu_attn_neon.hpp"
+ #define NEON_DISPATCH(...) \
+ case cpu_attention::ISA::NEON: { \
+ using attn_impl = cpu_attention::AttentionImpl; \
+ return __VA_ARGS__(); \
+ }
+#else
+ #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
+#endif // #ifdef __aarch64__
+
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
+ NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl
#include
#include
+#if defined(__APPLE__)
+  #include <sys/sysctl.h>
+#endif
+
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
+#include "utils.hpp"
namespace cpu_attention {
-enum class ISA { AMX, VEC, VEC16 };
+enum class ISA { AMX, VEC, VEC16, NEON };
template
class AttentionImpl {};
@@ -139,6 +143,12 @@ struct AttentionMetadata {
case ISA::VEC:
ss << "VEC, ";
break;
+ case ISA::VEC16:
+ ss << "VEC16, ";
+ break;
+ case ISA::NEON:
+ ss << "NEON, ";
+ break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
@@ -741,9 +751,21 @@ class AttentionScheduler {
static int64_t get_available_l2_size() {
static int64_t size = []() {
+#if defined(__APPLE__)
+ // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
+ int64_t l2_cache_size = 0;
+ size_t len = sizeof(l2_cache_size);
+ if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
+ l2_cache_size > 0) {
+ return l2_cache_size >> 1; // use 50% of L2 cache
+ }
+ // Fallback if sysctlbyname fails
+ return 128LL * 1024 >> 1; // use 50% of 128KB
+#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
return l2_cache_size >> 1; // use 50% of L2 cache
+#endif
}();
return size;
}
@@ -816,15 +838,21 @@ struct VecTypeTrait {
using vec_t = vec_op::FP32Vec16;
};
+// ARM only supports BF16 with ARMv8.6-A extension
+#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
+#else
template <>
struct VecTypeTrait {
using vec_t = vec_op::BF16Vec16;
};
+#endif
+#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait {
using vec_t = vec_op::FP16Vec16;
};
+#endif
template
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
@@ -1586,9 +1614,17 @@ class AttentionMainLoop {
if (use_sink) {
alignas(64) float s_aux_fp32[16];
+#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
+ // ARM without native BF16 support: manual conversion
+ for (int i = 0; i < 16; ++i) {
+ s_aux_fp32[i] = static_cast(curr_s_aux[i]);
+ }
+#else
+ // All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
+#endif
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp
new file mode 100644
index 000000000000..827f0cfbc718
--- /dev/null
+++ b/csrc/cpu/cpu_attn_neon.hpp
@@ -0,0 +1,386 @@
+#ifndef CPU_ATTN_NEON_HPP
+#define CPU_ATTN_NEON_HPP
+
+#include "cpu_attn_impl.hpp"
+#include <arm_neon.h>
+#include <cstring>
+namespace cpu_attention {
+
+namespace {
+
+#define BLOCK_SIZE_ALIGNMENT 32
+#define HEAD_SIZE_ALIGNMENT 32
+#define MAX_Q_HEAD_NUM_PER_ITER 16
+
+// These do not use vectorized class for loading / converting
+// because csrc/cpu/cpu_types_arm.hpp does not have fallback options
+// for vec_op::BF16Vec* / vec_op::BF16Vec* on Arm HW that
+// doesn't support BF16.
+// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* for consistency.
+template <typename kv_cache_t>
+FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0,
+ float32x4_t& b1);
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const float* p, float32x4_t& b0,
+ float32x4_t& b1) {
+ b0 = vld1q_f32(p + 0);
+ b1 = vld1q_f32(p + 4);
+}
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const c10::Half* p,
+ float32x4_t& b0,
+ float32x4_t& b1) {
+  const float16_t* h = reinterpret_cast<const float16_t*>(p);
+ float16x8_t v = vld1q_f16(h);
+ b0 = vcvt_f32_f16(vget_low_f16(v));
+ b1 = vcvt_f32_f16(vget_high_f16(v));
+}
+
+template <>
+FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p,
+ float32x4_t& b0,
+ float32x4_t& b1) {
+  const uint16_t* u = reinterpret_cast<const uint16_t*>(p);
+#ifdef ARM_BF16_SUPPORT
+ uint16x8_t u0 = vld1q_u16(u);
+ bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0);
+ b0 = vcvtq_low_f32_bf16(bf0);
+ b1 = vcvtq_high_f32_bf16(bf0);
+#else
+ uint16x8_t x0 = vld1q_u16(u);
+ uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16);
+ uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16);
+ b0 = vreinterpretq_f32_u32(lo);
+ b1 = vreinterpretq_f32_u32(hi);
+#endif
+}
+
+// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs
+// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
+// #FMLAs = (K // 4) * (4 * 2 * M)
+// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
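+// e.g. for M = 8 with a 2-byte (fp16/bf16) KV cache, each unrolled group of
+// 4 k-steps issues 8 + 4 = 12 loads and 4 * 2 * 8 = 64 FMLAs.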
+template <int M, typename kv_cache_t>
+FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4(
+ const float* __restrict A, // [M x K],
+ const kv_cache_t* __restrict B, // [K x 8],
+ float* __restrict C, // [M x 8],
+ int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
+ // kernel supports max M of 8, as it'd spill for larger M
+ static_assert(1 <= M && M <= 8, "M must be in [1,8]");
+
+// helpers for per-M codegen
+#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
+#define IF_M(i) if constexpr (M > (i))
+
+ // A row base pointers
+#define DECL_A(i) const float* a##i = A + (i) * lda;
+ ROWS_APPLY(DECL_A)
+#undef DECL_A
+
+ // declare 2 accumulators per row of M
+#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1;
+ ROWS_APPLY(DECL_ACC)
+#undef DECL_ACC
+
+ // initialize accumulators
+#define INIT_ACC(i) \
+ IF_M(i) { \
+ if (accumulate) { \
+ acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \
+ acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \
+ } else { \
+ acc##i##_0 = vdupq_n_f32(0.f); \
+ acc##i##_1 = vdupq_n_f32(0.f); \
+ } \
+ }
+ ROWS_APPLY(INIT_ACC)
+#undef INIT_ACC
+
+ int32_t k = 0;
+
+ // K unrolled by 4
+ for (; k + 3 < K; k += 4) {
+ // load A[k..k+3] for each active row (M)
+#define LOAD_A4(i) \
+ float32x4_t a##i##v; \
+ IF_M(i) a##i##v = vld1q_f32(a##i + k);
+ ROWS_APPLY(LOAD_A4)
+#undef LOAD_A4
+
+ // helper: FMA lane L from aiv
+#define FMAS_LANE(i, aiv, L) \
+ IF_M(i) { \
+ acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \
+ acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \
+ }
+
+ // k + 0
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb, b0, b1);
+#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
+ ROWS_APPLY(STEP_K0)
+#undef STEP_K0
+ }
+ // k + 1
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb, b0, b1);
+#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
+ ROWS_APPLY(STEP_K1)
+#undef STEP_K1
+ }
+ // k + 2
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb, b0, b1);
+#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
+ ROWS_APPLY(STEP_K2)
+#undef STEP_K2
+ }
+ // k + 3
+ {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb, b0, b1);
+#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
+ ROWS_APPLY(STEP_K3)
+#undef STEP_K3
+ }
+#undef FMAS_LANE
+ }
+
+ // K tail
+ for (; k < K; ++k) {
+ float32x4_t b0, b1;
+ load_row8_B_as_f32(B + (int64_t)k * ldb, b0, b1);
+#define TAIL_ROW(i) \
+ IF_M(i) { \
+ float32x4_t ai = vdupq_n_f32(*(a##i + k)); \
+ acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \
+ acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \
+ }
+ ROWS_APPLY(TAIL_ROW)
+#undef TAIL_ROW
+ }
+
+ // store accumulators to C
+#define STORE_ROW(i) \
+ IF_M(i) { \
+ vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \
+ vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \
+ }
+ ROWS_APPLY(STORE_ROW)
+#undef STORE_ROW
+
+#undef ROWS_APPLY
+#undef IF_M
+}
+
+template <int32_t N, typename kv_cache_t>
+FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A,
+ const kv_cache_t* __restrict B,
+ float* __restrict C, int32_t M,
+ int32_t K, int64_t lda,
+ int64_t ldb, int64_t ldc,
+ bool accumulate) {
+ // micro kernel is Mx8
+ static_assert(N % 8 == 0, "N must be a multiple of 8");
+ for (int32_t m = 0; m < M;) {
+ int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
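+    // e.g. M == 11 is covered as row tiles of 8, then 2, then 1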
+ const float* Ab = A + m * lda;
+ float* Cb = C + m * ldc;
+
+ for (int32_t n = 0; n < N; n += 8) {
+ const kv_cache_t* Bn = B + n;
+ float* Cn = Cb + n;
+ switch (mb) {
+ case 8:
+ gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ case 4:
+ gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ case 2:
+ gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ default:
+ gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
+ K, accumulate);
+ break;
+ }
+ }
+ // no tail loop for N as it's guaranteed to be a multiple of 8
+ m += mb;
+ }
+}
+
+template
+class TileGemmNeonFMLA {
+ public:
+ template
+ FORCE_INLINE static void gemm(const int32_t m_size,
+ float* __restrict__ a_tile,
+ kv_cache_t* __restrict__ b_tile,
+ float* __restrict__ c_tile, const int64_t lda,
+ const int64_t ldb, const int64_t ldc,
+ const int32_t block_size,
+ const int32_t dynamic_k_size,
+ const bool accum_c) {
+ if constexpr (phase == AttentionGemmPhase::QK) {
+ gemm_macro_neon_fmla_Mx8_Ku4(
+ a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
+ } else {
+ gemm_macro_neon_fmla_Mx8_Ku4(
+ a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
+ accum_c);
+ }
+ }
+};
+
+} // namespace
+
+// this is similar to "ISA::VEC" at the moment
+template
+class AttentionImpl {
+ public:
+ using query_t = scalar_t;
+ using q_buffer_t = float;
+ using kv_cache_t = scalar_t;
+ using logits_buffer_t = float;
+ using partial_output_buffer_t = float;
+ using prob_buffer_t = float;
+
+ constexpr static int64_t BlockSizeAlignment =
+ BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases
+ constexpr static int64_t HeadDimAlignment =
+ HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase
+ constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
+ constexpr static int64_t HeadDim = head_dim;
+ constexpr static ISA ISAType = ISA::NEON;
+ constexpr static bool scale_on_logits = false; // apply scale on q_buffer
+
+ static_assert(HeadDim % HeadDimAlignment == 0);
+ // the gemm micro kernel is Mx8
+ static_assert(HeadDimAlignment % 8 == 0);
+ static_assert(BlockSizeAlignment % 8 == 0);
+
+ public:
+ template typename attention>
+ FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
+ attention> attention_iteration;
+ attention_iteration(CPU_ATTENTION_PARAMS);
+ }
+
+ // k_cache_token_group_stride: stride of K cache when move to next
+ // BlockSizeAlignment tokens in a block
+ constexpr static int64_t k_cache_token_group_stride(
+ const int32_t block_size) {
+ return BlockSizeAlignment; // layout of k_cache block is [head_dim,
+ // block_size], row-major
+ }
+
+ // v_cache_token_group_stride: stride of V cache when move to next
+ // BlockSizeAlignment tokens in a block
+ constexpr static int64_t v_cache_token_group_stride(
+ const int32_t block_size) {
+ return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
+ // head_dim], row-major
+ }
+
+ // v_cache_head_group_stride: stride of V cache when move to next
+ // HeadDimAlignment head dims in a block
+ constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
+ return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
+ // row-major
+ }
+
+ // Copy q to q_buffer and cast it to fp32
+ static void copy_q_heads_tile(
+ scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
+ float* __restrict__ q_buffer, const int32_t q_num,
+ const int32_t q_heads_per_kv, const int64_t q_num_stride,
+ const int64_t q_head_stride, float scale) {
+ static_assert(head_dim % 16 == 0);
+ constexpr int32_t unroll_size = head_dim / 16;
+    using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
+
+ vec_op::FP32Vec16 scale_vec(scale);
+ for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
+ for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
+ scalar_t* __restrict__ curr_q =
+ src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
+ float* __restrict__ curr_q_buffer =
+ q_buffer + q_num_idx * q_heads_per_kv * head_dim +
+ q_head_idx * head_dim;
+
+        vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
+ load_vec_t vec(curr_q);
+ vec_op::FP32Vec16 fp32_vec(vec);
+ fp32_vec = fp32_vec * scale_vec;
+ fp32_vec.save(curr_q_buffer);
+
+ curr_q += 16;
+ curr_q_buffer += 16;
+ });
+ }
+ }
+ }
+
+ // reshape K as column-major and V as row-major
+ static void reshape_and_cache(
+ const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
+ scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
+ const int64_t* __restrict__ slot_mapping, const int64_t token_num,
+ const int64_t key_token_num_stride, const int64_t value_token_num_stride,
+ const int64_t head_num, const int64_t key_head_num_stride,
+ const int64_t value_head_num_stride, const int64_t num_blocks,
+ const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
+ const int64_t block_size, const int64_t block_size_stride) {
+#pragma omp parallel for collapse(2)
+ for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
+ for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
+ const int64_t pos = slot_mapping[token_idx];
+ if (pos < 0) {
+ // skip
+ continue;
+ }
+
+ const int64_t block_idx = pos / block_size;
+ const int64_t block_offset = pos % block_size;
+ {
+ // Write Key
+ const scalar_t* key_start_ptr = key +
+ token_idx * key_token_num_stride +
+ head_idx * key_head_num_stride;
+ scalar_t* key_cache_start_ptr =
+ key_cache + block_idx * num_blocks_stride +
+ head_idx * cache_head_num_stride + block_offset;
+
+#pragma GCC unroll 8
+ for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
+ key_cache_start_ptr[j] = key_start_ptr[i];
+ }
+ }
+ {
+ // Write Value
+ const scalar_t* value_start_ptr = value +
+ token_idx * value_token_num_stride +
+ head_idx * value_head_num_stride;
+ scalar_t* value_cache_start_ptr =
+ value_cache + block_idx * num_blocks_stride +
+ head_idx * cache_head_num_stride + block_offset * head_dim;
+ std::memcpy(value_cache_start_ptr, value_start_ptr,
+ sizeof(scalar_t) * head_dim);
+ }
+ }
+ }
+ }
+};
+} // namespace cpu_attention
+
+#endif // #ifndef CPU_ATTN_NEON_HPP
diff --git a/csrc/cpu/cpu_types_scalar.hpp b/csrc/cpu/cpu_types_scalar.hpp
index 1a9278bc662e..f9da78283da5 100644
--- a/csrc/cpu/cpu_types_scalar.hpp
+++ b/csrc/cpu/cpu_types_scalar.hpp
@@ -26,10 +26,6 @@ namespace vec_op {
#define FORCE_INLINE __attribute__((always_inline)) inline
-#define __max(a, b) ((a) > (b) ? (a) : (b))
-#define __min(a, b) ((a) < (b) ? (a) : (b))
-#define __abs(a) ((a) < (0) ? (0 - a) : (a))
-
typedef struct f16x8_t {
uint16_t val[8];
} f16x8_t;
@@ -99,7 +95,7 @@ struct FP16Vec16 : public Vec {
void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
- int num = __min(elem_num, VEC_ELEM_NUM);
+ int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
@@ -128,7 +124,7 @@ struct BF16Vec16 : public Vec {
void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }
void save(void* ptr, const int elem_num) const {
- int num = __min(elem_num, VEC_ELEM_NUM);
+ int num = std::min(elem_num, VEC_ELEM_NUM);
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
}
};
@@ -143,9 +139,9 @@ struct BF16Vec32 : public Vec {
explicit BF16Vec32(f16x32_t data) : reg(data) {};
explicit BF16Vec32(BF16Vec8& vec8_data) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ unroll_loop([&vec8_data, this](int i) {
reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
- }
+ });
}
void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }
@@ -157,15 +153,11 @@ struct FP32Vec4 : public Vec {
f32x4_t reg;
explicit FP32Vec4(float v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = v;
- }
+ unroll_loop([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec4() {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = 0.0f;
- }
+ unroll_loop([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec4(const float* ptr)
@@ -182,15 +174,11 @@ struct FP32Vec8 : public Vec {
f32x8_t reg;
explicit FP32Vec8(float v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = v;
- }
+ unroll_loop([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec8() {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = 0.0f;
- }
+ unroll_loop([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec8(const float* ptr)
@@ -201,78 +189,68 @@ struct FP32Vec8 : public Vec {
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = fp16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
FP32Vec8(const BF16Vec8& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = bf16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
float reduce_sum() const {
float result = 0;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result += reg.val[i];
- }
+ unroll_loop(
+ [&result, this](int i) { result += reg.val[i]; });
return result;
}
FP32Vec8 exp() const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = expf(reg.val[i]);
- }
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = expf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 tanh() const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = tanhf(reg.val[i]);
- }
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 er() const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = erf(reg.val[i]);
- }
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = erf(reg.val[i]); });
return FP32Vec8(ret);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] * b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator+(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] + b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator-(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] - b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
return FP32Vec8(ret);
}
FP32Vec8 operator/(const FP32Vec8& b) const {
f32x8_t ret;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- ret.val[i] = reg.val[i] / b.reg.val[i];
- }
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
return FP32Vec8(ret);
}
@@ -284,15 +262,11 @@ struct FP32Vec16 : public Vec {
f32x16_t reg;
explicit FP32Vec16(float v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = v;
- }
+ unroll_loop([&v, this](int i) { reg.val[i] = v; });
}
explicit FP32Vec16() {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = 0.0f;
- }
+ unroll_loop([this](int i) { reg.val[i] = 0.0f; });
}
explicit FP32Vec16(const float* ptr)
@@ -301,29 +275,27 @@ struct FP32Vec16 : public Vec {
explicit FP32Vec16(f32x16_t data) : reg(data) {};
FP32Vec16(const FP32Vec4& data) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ unroll_loop([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
- }
+ });
}
FP32Vec16(const FP32Vec8& data) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ unroll_loop([&data, this](int i) {
reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
- }
+ });
}
FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = fp16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const BF16Vec16& v) {
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- reg.val[i] = bf16_to_float(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
@@ -331,82 +303,74 @@ struct FP32Vec16 : public Vec {
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
FP32Vec16 operator*(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] * b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 operator+(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] + b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 operator-(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] - b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 operator/(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = reg.val[i] / b.reg.val[i];
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
+ return FP32Vec16(ret);
}
FP32Vec16 max(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = __max(reg.val[i], b.reg.val[i]);
- }
- return result;
+ f32x16_t ret;
+ unroll_loop([&ret, &b, this](int i) {
+ ret.val[i] = std::max(reg.val[i], b.reg.val[i]);
+ });
+ return FP32Vec16(ret);
}
FP32Vec16 min(const FP32Vec16& b) const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = __min(reg.val[i], b.reg.val[i]);
- }
- return result;
+ f32x16_t ret;
+ unroll_loop([&ret, &b, this](int i) {
+ ret.val[i] = std::min(reg.val[i], b.reg.val[i]);
+ });
+ return FP32Vec16(ret);
}
FP32Vec16 abs() const {
- FP32Vec16 result(0.0f);
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result.reg.val[i] = __abs(reg.val[i]);
- }
- return result;
+ f32x16_t ret;
+ unroll_loop(
+ [&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); });
+ return FP32Vec16(ret);
}
float reduce_sum() const {
float result = 0.0f;
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result += reg.val[i];
- }
+ unroll_loop(
+ [&result, this](int i) { result += reg.val[i]; });
return result;
}
float reduce_max() const {
- float result = reg.val[0];
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result = __max(reg.val[i], result);
- }
+ float result = std::numeric_limits<float>::lowest();
+ unroll_loop(
+ [&result, this](int i) { result = std::max(reg.val[i], result); });
return result;
}
float reduce_min() const {
- float result = reg.val[0];
- for (int i = 0; i < VEC_ELEM_NUM; ++i) {
- result = __min(reg.val[i], result);
- }
+ float result = std::numeric_limits<float>::max();
+ unroll_loop(
+ [&result, this](int i) { result = std::min(reg.val[i], result); });
return result;
}
@@ -414,13 +378,9 @@ struct FP32Vec16 : public Vec {
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
float sum = 0.0;
- int start = idx * group_size;
- int end = (idx + 1) * group_size;
-
- for (; (start < VEC_ELEM_NUM) && (start < end); ++start) {
- sum += reg.val[start];
- }
-
+ const int start = idx * group_size;
+ unroll_loop(
+ [&sum, &start, this](int i) { sum += reg.val[start + i]; });
return sum;
}
@@ -477,17 +437,13 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
- int i = 0;
- for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_fp16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
- int i = 0;
- for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_fp16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
@@ -495,17 +451,13 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
- int i = 0;
- for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_bf16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
- int i = 0;
- for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) {
- reg.val[i] = float_to_bf16(v.reg.val[i]);
- }
+ unroll_loop(
+ [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
}
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
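
The scalar header now funnels every elementwise loop through `vec_op::unroll_loop`, whose definition lives elsewhere in the CPU headers and is not part of this diff. As a rough mental model (the real signature may differ), a compile-time-unrolled helper can be sketched as:

```cpp
#include <utility>

// Illustrative sketch only; the actual vec_op::unroll_loop is defined in the
// vLLM CPU headers and its exact signature is not shown in this diff.
template <int N, typename F, int... Is>
constexpr void unroll_loop_impl(F&& f, std::integer_sequence<int, Is...>) {
  (f(Is), ...);  // expand f(0), f(1), ..., f(N-1) at compile time
}

template <int N, typename F>
constexpr void unroll_loop(F&& f) {
  unroll_loop_impl<N>(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

struct f32x16_t {
  float val[16];
};

// Usage mirroring the rewritten FP32Vec16 default constructor above.
inline f32x16_t zero_vec() {
  f32x16_t reg;
  unroll_loop<16>([&](int i) { reg.val[i] = 0.0f; });
  return reg;
}
```
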
diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp
index 51bca37e699b..9efd8b7ec14a 100644
--- a/csrc/cpu/cpu_types_vxe.hpp
+++ b/csrc/cpu/cpu_types_vxe.hpp
@@ -4,6 +4,7 @@
#include
#include
+#include
#include
namespace vec_op {
@@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec {
}
explicit FP32Vec8(const BF16Vec8& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
float reduce_sum() const {
@@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec {
}
FP32Vec8 exp() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::exp(ar.values[0]);
- ret.val[0][1] = std::exp(ar.values[1]);
- ret.val[0][2] = std::exp(ar.values[2]);
- ret.val[0][3] = std::exp(ar.values[3]);
- ret.val[1][0] = std::exp(ar.values[4]);
- ret.val[1][1] = std::exp(ar.values[5]);
- ret.val[1][2] = std::exp(ar.values[6]);
- ret.val[1][3] = std::exp(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ f32x4x2_t out;
+
+ const __vector float log2e = vec_splats(1.44269504088896341f);
+ const __vector float one = vec_splats(1.0f);
+ const __vector float min_x = vec_splats(-87.3f);
+ const __vector float max_x = vec_splats(88.7f);
+
+ // 5th-degree minimax polynomial for 2^r (r in [0,1))
+ const __vector float c1 = vec_splats(0.6931471805599453f);
+ const __vector float c2 = vec_splats(0.240226506959101f);
+ const __vector float c3 = vec_splats(0.05550410866482158f);
+ const __vector float c4 = vec_splats(0.009618129107628477f);
+ const __vector float c5 = vec_splats(0.0013333558146428443f);
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+
+ x = vec_max(x, min_x);
+ x = vec_min(x, max_x);
+
+ __vector float y = vec_mul(x, log2e);
+
+ __vector float kf = vec_floor(y);
+ __vector float r = vec_sub(y, kf);
+
+ __vector signed int k = vec_signed(kf);
+ const __vector signed int min_k = vec_splats((signed int)-126);
+ const __vector signed int max_k = vec_splats((signed int)127);
+ k = vec_min(vec_max(k, min_k), max_k);
+
+ // Build 2^k from exponent bits
+ __vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
+ __vector unsigned int bits = (__vector unsigned int)exp_int;
+ bits = vec_sl(bits, vec_splats((unsigned int)23));
+ __vector float pow2k = (__vector float)bits;
+
+ // Improved minimax polynomial
+ __vector float poly = vec_madd(c5, r, c4);
+ poly = vec_madd(poly, r, c3);
+ poly = vec_madd(poly, r, c2);
+ poly = vec_madd(poly, r, c1);
+ poly = vec_madd(poly, r, one);
+
+ out.val[i] = vec_mul(pow2k, poly);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 tanh() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::tanh(ar.values[0]);
- ret.val[0][1] = std::tanh(ar.values[1]);
- ret.val[0][2] = std::tanh(ar.values[2]);
- ret.val[0][3] = std::tanh(ar.values[3]);
- ret.val[1][0] = std::tanh(ar.values[4]);
- ret.val[1][1] = std::tanh(ar.values[5]);
- ret.val[1][2] = std::tanh(ar.values[6]);
- ret.val[1][3] = std::tanh(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+ const __vector float one = vec_splats(1.0f);
+ const __vector float two = vec_splats(2.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float sat =
+ vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
+
+ f32x4x2_t out;
+
+ for (int i = 0; i < 2; i++) {
+ __vector float x = reg.val[i];
+ __vector float ax = vec_abs(x);
+
+ // sign(x): +1 or -1
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // saturation mask: |x| > sat
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // 2x
+ __vector float two_x = vec_mul(x, two);
+
+ // Build a temporary FP32Vec8 with both lanes = 2x, reuse exp()
+ f32x4x2_t tmp;
+ tmp.val[0] = two_x;
+ tmp.val[1] = two_x;
+ FP32Vec8 exp_2x_vec(tmp);
+
+ FP32Vec8 e2x = exp_2x_vec.exp();
+ __vector float e = e2x.reg.val[i];
+
+ // tanh(x) = (e - 1) / (e + 1)
+ __vector float num = vec_sub(e, one);
+ __vector float den = vec_add(e, one);
+
+ __vector float t = vec_div(num, den);
+
+ // For large |x|, clamp to sign(x)
+ out.val[i] = vec_sel(t, sign, saturated);
+ }
+
+ return FP32Vec8(out);
}
FP32Vec8 er() const {
- // TODO: Vectorize this
- AliasReg ar;
- ar.reg = reg;
- f32x4x4_t ret;
- ret.val[0][0] = std::erf(ar.values[0]);
- ret.val[0][1] = std::erf(ar.values[1]);
- ret.val[0][2] = std::erf(ar.values[2]);
- ret.val[0][3] = std::erf(ar.values[3]);
- ret.val[1][0] = std::erf(ar.values[4]);
- ret.val[1][1] = std::erf(ar.values[5]);
- ret.val[1][2] = std::erf(ar.values[6]);
- ret.val[1][3] = std::erf(ar.values[7]);
- return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+ // A&S 7.1.26 approximation:
+ // erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t * exp(-x^2))
+ // with t = 1 / (1 + p*|x|), p = 0.3275911
+
+ const __vector float one = vec_splats(1.0f);
+ const __vector float zero = vec_splats(0.0f);
+ const __vector float p = vec_splats(0.3275911f);
+
+ // Polynomial coeffs
+ const __vector float a1 = vec_splats(0.254829592f);
+ const __vector float a2 = vec_splats(-0.284496736f);
+ const __vector float a3 = vec_splats(1.421413741f);
+ const __vector float a4 = vec_splats(-1.453152027f);
+ const __vector float a5 = vec_splats(1.061405429f);
+
+ // Threshold where erf(x) ~ sign(x)
+ const __vector float sat = vec_splats(6.0f);
+
+ f32x4x2_t out;
+
+ for (int lane = 0; lane < 2; lane++) {
+ __vector float x = reg.val[lane];
+ __vector float ax = vec_abs(x);
+
+ // sign(x)
+ __vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
+
+ // |x| > 6 → erf(x) = ±1
+ __vector __bool int saturated = vec_cmpgt(ax, sat);
+
+ // t = 1 / (1 + p * |x|)
+ __vector float t = vec_madd(p, ax, one);
+ t = vec_div(one, t);
+
+ // poly = a5
+ __vector float poly = a5;
+ poly = vec_madd(poly, t, a4);
+ poly = vec_madd(poly, t, a3);
+ poly = vec_madd(poly, t, a2);
+ poly = vec_madd(poly, t, a1);
+
+ // full polynomial: poly = poly * t
+ poly = vec_mul(poly, t);
+
+ // Compute exp(-x^2)
+ __vector float x2 = vec_mul(x, x);
+ __vector float neg_x2 = vec_neg(x2);
+
+ f32x4x2_t tmp;
+ tmp.val[0] = neg_x2;
+ tmp.val[1] = neg_x2;
+ FP32Vec8 exp_neg_x2(tmp);
+
+ FP32Vec8 e = exp_neg_x2.exp();
+ __vector float ex = e.reg.val[lane];
+
+ // erf(x) = sign * (1 - poly * exp(-x^2))
+ __vector float term = vec_mul(poly, ex);
+ __vector float y = vec_sub(one, term);
+ y = vec_mul(y, sign);
+
+ // saturated → ±1
+ __vector float sat_val = vec_mul(sign, one);
+ out.val[lane] = vec_sel(y, sat_val, saturated);
+ }
+
+ return FP32Vec8(out);
+ }
+ // Elementwise sigmoid(x) = 1 / (1 + exp(-x))
+ FP32Vec8 sigmoid() const {
+ const __vector float one = vec_splats(1.0f);
+
+ f32x4x2_t neg;
+ for (int i = 0; i < 2; ++i) {
+ neg.val[i] = vec_neg(reg.val[i]);
+ }
+
+ FP32Vec8 neg_x(neg);
+ FP32Vec8 e = neg_x.exp(); // exp(-x)
+
+ f32x4x2_t denom;
+ for (int i = 0; i < 2; ++i) {
+ denom.val[i] = vec_add(one, e.reg.val[i]);
+ }
+
+ FP32Vec8 denom_vec(denom);
+ FP32Vec8 one_vec(1.0f);
+
+ return one_vec / denom_vec;
+ }
+
+ // Tanh-based GELU:
+ // gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
+ FP32Vec8 gelu_tanh() const {
+ const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
+ const __vector float k_0_0447 = vec_splats(0.044715f);
+
+ f32x4x2_t x2, x3, inner;
+ for (int i = 0; i < 2; ++i) {
+ __vector float x = reg.val[i];
+ x2.val[i] = vec_mul(x, x); // x^2
+ x3.val[i] = vec_mul(x2.val[i], x); // x^3
+ __vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
+ inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
+ }
+
+ FP32Vec8 inner_vec(inner);
+ FP32Vec8 t = inner_vec.tanh(); // tanh part
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ FP32Vec8 x_vec(*this);
+ return x_vec * half_vec * (one_vec + t);
+ }
+
+ // Erf-based GELU:
+ // gelu(x) = 0.5 * x * (1 + erf(x / √2))
+ FP32Vec8 gelu_erf() const {
+ const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
+ FP32Vec8 x_vec(*this);
+
+ f32x4x2_t scaled;
+ for (int i = 0; i < 2; ++i) {
+ scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
+ }
+ FP32Vec8 x_scaled(scaled);
+
+ FP32Vec8 erf_x = x_scaled.er();
+
+ FP32Vec8 one_vec(1.0f);
+ FP32Vec8 half_vec(0.5f);
+
+ return x_vec * half_vec * (one_vec + erf_x);
+ }
+
+ // Elementwise reciprocal: 1/x (scalar per lane, for correctness)
+ FP32Vec8 rcp() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / in.values[i];
+ }
+ return FP32Vec8(out.reg);
+ }
+
+ // Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
+ FP32Vec8 rsqrt() const {
+ AliasReg in, out;
+ in.reg = reg;
+
+ for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+ out.values[i] = 1.0f / std::sqrt(in.values[i]);
+ }
+ return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
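
For reference, the vectorized `exp()` above uses the usual range-reduction recipe: clamp x, write x * log2(e) = k + r with k integer and r in [0, 1), rebuild 2^k directly from IEEE-754 exponent bits, and evaluate a degree-5 polynomial for 2^r. A scalar sketch of the same scheme (illustrative only, not part of the patch):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar model of the vectorized exp() above: exp(x) = 2^k * 2^r with
// k = floor(x * log2(e)) and r = x * log2(e) - k in [0, 1).
inline float exp_approx(float x) {
  x = std::min(std::max(x, -87.3f), 88.7f);  // clamp to avoid overflow/underflow
  const float y = x * 1.44269504088896341f;  // x * log2(e)
  const float kf = std::floor(y);
  const float r = y - kf;

  int k = static_cast<int>(kf);
  k = std::min(std::max(k, -126), 127);

  uint32_t bits = static_cast<uint32_t>(k + 127) << 23;  // 2^k from exponent bits
  float pow2k;
  std::memcpy(&pow2k, &bits, sizeof(pow2k));

  // Degree-5 polynomial approximating 2^r on [0, 1)
  float poly = 0.0013333558146428443f;
  poly = poly * r + 0.009618129107628477f;
  poly = poly * r + 0.05550410866482158f;
  poly = poly * r + 0.240226506959101f;
  poly = poly * r + 0.6931471805599453f;
  poly = poly * r + 1.0f;

  return pow2k * poly;
}
```
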
@@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec {
}
explicit FP32Vec16(const BF16Vec16& v) {
- reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
- reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
- reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
- reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
+ // On big-endian s390x, place BF16 first to get correct byte order
+ reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
+ reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
+ reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
+ reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
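
The merge-order fix above matters because bfloat16 is simply the upper 16 bits of an IEEE-754 float32; on big-endian s390x the BF16 halfword therefore has to land in the most significant half of each 32-bit lane. A scalar equivalent of the widening:

```cpp
#include <cstdint>
#include <cstring>

// Scalar equivalent of the BF16 -> FP32 widening fixed above: bfloat16 is the
// upper 16 bits of a float32, so widening shifts it into the high half and
// zero-fills the low half.
inline float bf16_to_float_ref(uint16_t bf16_bits) {
  const uint32_t bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```
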
@@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec {
return result;
}
+ FP32Vec16 max(const FP32Vec16& b) const {
+ return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
+ vec_max(reg.val[1], b.reg.val[1]),
+ vec_max(reg.val[2], b.reg.val[2]),
+ vec_max(reg.val[3], b.reg.val[3])}));
+ }
+
+ float reduce_max() const {
+ AliasReg ar;
+ ar.reg = reg;
+ float result = ar.values[0];
+ unroll_loop([&result, &ar](int i) {
+ if (ar.values[i] > result) result = ar.values[i];
+ });
+ return result;
+ }
+
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
@@ -402,15 +628,14 @@ struct VecType {
using vec_type = BF16Vec8;
};
+// On s390x, FP16 (Half) is not natively supported; use FP32 vectors instead
+using FP16Vec16 = FP32Vec16;
+
template
void storeFP32(float v, T* ptr) {
*ptr = v;
}
-inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
- acc = acc + a * b;
-}
-
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
@@ -429,6 +654,79 @@ inline void storeFP32(float v, c10::BFloat16* ptr) {
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
+// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
+// intrinsics
+
+// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
+FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_madd(a.reg, b.reg, acc.reg);
+}
+
+// FP32Vec8 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+// FP32Vec16 FMA: acc = acc + (a * b)
+FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Multiply-Subtract: acc = acc - (a * b)
+FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_msub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Add: acc = -(a * b) + acc
+FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
+// Negative Multiply-Subtract: acc = -(a * b) - acc
+FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
+ acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
+}
+
+FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+}
+
+FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
+ acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
+ acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
+ acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
+ acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
+}
+
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
@@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
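
The added lines change the FP32 to BF16 conversion from plain truncation to round-to-nearest-even: add 0x7FFF plus the least significant kept mantissa bit before dropping the low 16 bits, and keep the NaN select on the original value. A scalar model (the exact quiet-NaN payload used by the header's `nan` constant is defined outside this diff; 0x7FC0 below is a stand-in):

```cpp
#include <cstdint>
#include <cstring>

// Scalar model of the FP32 -> BF16 conversion above: round-to-nearest-even by
// adding 0x7FFF plus the least significant kept mantissa bit, then truncating.
inline uint16_t float_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if ((bits & 0x7F800000u) == 0x7F800000u && (bits & 0x007FFFFFu) != 0u) {
    return 0x7FC0;  // quiet NaN stand-in, mirrors the vec_fp_test_data_class() path
  }
  const uint32_t lsb = (bits >> 16) & 1u;
  bits += 0x7FFFu + lsb;  // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);
}
```
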
@@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
+ __vector unsigned int lsb0 = inp0 >> sh16;
+ __vector unsigned int lsb1 = inp1 >> sh16;
+ __vector unsigned int lsb2 = inp2 >> sh16;
+ __vector unsigned int lsb3 = inp3 >> sh16;
+ lsb0 = lsb0 & one;
+ lsb1 = lsb1 & one;
+ lsb2 = lsb2 & one;
+ lsb3 = lsb3 & one;
+ __vector unsigned int rnd0 = lsb0 + bias;
+ __vector unsigned int rnd1 = lsb1 + bias;
+ __vector unsigned int rnd2 = lsb2 + bias;
+ __vector unsigned int rnd3 = lsb3 + bias;
+ inp0 = inp0 + rnd0;
+ inp1 = inp1 + rnd1;
+ inp2 = inp2 + rnd2;
+ inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
@@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
- inp0 = vec_sel(inp0, nan, sel0) >> sh16;
- inp1 = vec_sel(inp1, nan, sel1) >> sh16;
- inp2 = vec_sel(inp2, nan, sel2) >> sh16;
- inp3 = vec_sel(inp3, nan, sel3) >> sh16;
+ inp0 = vec_sel(inp0, nan, sel0);
+ inp1 = vec_sel(inp1, nan, sel1);
+ inp2 = vec_sel(inp2, nan, sel2);
+ inp3 = vec_sel(inp3, nan, sel3);
+ inp0 = inp0 >> sh16;
+ inp1 = inp1 >> sh16;
+ inp2 = inp2 >> sh16;
+ inp3 = inp3 >> sh16;
+
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
-inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
+// 1D softmax over `n` elements in `input`, writes result to `output`.
+// Uses FP32Vec8 for main body, scalar tail handling.
+// Requirement: n > 0
+FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: find max ----------
+ float max_val = -std::numeric_limits<float>::infinity();
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 v(input + i);
+ FP32Vec8::AliasReg ar;
+ ar.reg = v.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ if (ar.values[j] > max_val) max_val = ar.values[j];
+ }
+ }
+ for (; i < n; ++i) {
+ if (input[i] > max_val) max_val = input[i];
+ }
+
+ // ---------- Pass 2: compute exp(x - max) and sum ----------
+ float sum = 0.0f;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = input[i + j] - max_val;
+ }
+
+ FP32Vec8 v(tmp);
+ FP32Vec8 e = v.exp();
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = e.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ output[i + j] = ar.values[j];
+ sum += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float x = input[i] - max_val;
+ float ex = std::exp(x); // scalar tail
+ output[i] = ex;
+ sum += ex;
+ }
+
+ // ---------- Pass 3: normalize ----------
+ float inv_sum = 1.0f / sum;
+ i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ float tmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ tmp[j] = output[i + j] * inv_sum;
+ }
+ FP32Vec8 v(tmp);
+ v.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] *= inv_sum;
+ }
+}
+
+// 1D RMSNorm kernel:
+// input: x[0..n-1]
+// weight: w[0..n-1] (gamma), may be nullptr
+// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
+// eps: small epsilon for numerical stability
+FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
+ const float* weight, int n, float eps) {
+ if (n <= 0) return;
+
+ // ---------- Pass 1: compute sum of squares ----------
+ float sum_sq = 0.0f;
+ int i = 0;
+
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ FP32Vec8 sq = x_vec * x_vec;
+
+ FP32Vec8::AliasReg ar;
+ ar.reg = sq.reg;
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ sum_sq += ar.values[j];
+ }
+ }
+
+ // Tail
+ for (; i < n; ++i) {
+ float v = input[i];
+ sum_sq += v * v;
+ }
+
+ float mean_sq = sum_sq / static_cast<float>(n);
+ float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
+
+ // ---------- Pass 2: scale (and apply weight if given) ----------
+ const float inv_rms_f = inv_rms;
+ i = 0;
+
+ if (weight) {
+ // with gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+
+ float wtmp[FP32Vec8::VEC_ELEM_NUM];
+ for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
+ wtmp[j] = weight[i + j];
+ }
+ FP32Vec8 w_vec(wtmp);
+
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec * w_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f * weight[i];
+ }
+ } else {
+ // without gamma
+ for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
+ FP32Vec8 x_vec(input + i);
+ FP32Vec8 scale_vec(inv_rms_f);
+ FP32Vec8 y = x_vec * scale_vec;
+ y.save(output + i);
+ }
+
+ for (; i < n; ++i) {
+ output[i] = input[i] * inv_rms_f;
+ }
+ }
+}
+
+// Prefetch data to cache for better memory access performance
+FORCE_INLINE void prefetch(const void* addr) {
+ __builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
+}
}; // namespace vec_op
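
A minimal usage sketch of the two helpers added above; the include path and buffer sizes are illustrative, and both functions expect contiguous float buffers with n > 0:

```cpp
#include <vector>
// #include "cpu_types_vxe.hpp"  // provides vec_op::softmax_fp32vec8 / rmsnorm_fp32vec8

void run_helpers() {
  std::vector<float> logits = {1.0f, 2.0f, 3.0f, 0.5f, -1.0f};
  std::vector<float> probs(logits.size());
  vec_op::softmax_fp32vec8(probs.data(), logits.data(),
                           static_cast<int>(logits.size()));

  std::vector<float> x(4096, 0.1f), gamma(4096, 1.0f), y(4096);
  vec_op::rmsnorm_fp32vec8(y.data(), x.data(), gamma.data(),
                           static_cast<int>(x.size()), /*eps=*/1e-6f);
}
```
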
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index 7ddf028e6e13..6f51277f7844 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec {
explicit FP16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
+ explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
+
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec {
explicit BF16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
+ explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
+
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec {
explicit FP32Vec16(__m512 data) : reg(data) {}
+ // unpack 4-bit values using a 16-entry LUT
+ explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
+ int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
+ int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
+ int64_t value_0 = value & mask_0;
+ int64_t value_1 = value & mask_1;
+ __m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
+ __m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
+ vec_0 = _mm_cvtepu8_epi16(vec_0);
+ vec_1 = _mm_cvtepu8_epi16(vec_1);
+ vec_1 = _mm_slli_epi16(vec_1, 4);
+ __m128i vec = _mm_or_si128(vec_0, vec_1);
+ __m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
+ reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
+ }
+
explicit FP32Vec16(const FP32Vec4& data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
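
The new `FP32Vec16(int64_t, const FP32Vec16&)` constructor above unpacks sixteen 4-bit indices from one 64-bit word and maps each through a 16-entry float lookup table via `_mm512_permutexvar_ps`. A scalar model of the same nibble ordering (low nibble of byte i first, then its high nibble):

```cpp
#include <array>
#include <cstdint>

// Scalar model of the 4-bit unpack constructor above: each index selects one
// of the 16 floats in the lookup table.
inline std::array<float, 16> depack_4bit(int64_t value,
                                         const std::array<float, 16>& lut) {
  std::array<float, 16> out{};
  const uint64_t v = static_cast<uint64_t>(value);
  for (int byte = 0; byte < 8; ++byte) {
    const uint8_t b = static_cast<uint8_t>(v >> (8 * byte));
    out[2 * byte] = lut[b & 0x0F];             // low nibble
    out[2 * byte + 1] = lut[(b >> 4) & 0x0F];  // high nibble
  }
  return out;
}
```
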
@@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec {
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
- template
- float reduce_sub_sum(int idx) {
- static_assert(VEC_ELEM_NUM % group_size == 0);
- constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
- __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
- return _mm512_mask_reduce_add_ps(mask, reg);
- }
-
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
void save(float* ptr, const int elem_num) const {
@@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
_mm512_stream_ps((float*)ptr, vec.reg);
}
+
+static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
+ void* ptr) {
+ __m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
+ __m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
+ vec_1 = _mm512_slli_epi32(vec_1, 16);
+ vec_0 = _mm512_or_si512(vec_0, vec_1);
+ _mm512_storeu_epi32(ptr, vec_0);
+}
+
+static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
+ void* ptr) {
+ __m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
+ __m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
+ vec_1 = _mm512_slli_epi32(vec_1, 16);
+ vec_0 = _mm512_or_si512(vec_0, vec_1);
+ _mm512_storeu_epi32(ptr, vec_0);
+}
+
#endif
inline void mem_barrier() { _mm_mfence(); }
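
`interleave_save` above zips two 16-element vectors element by element so that each 32-bit lane of the stored tile holds values from two consecutive K positions, which is the layout `_tile_dpbf16ps` expects for the B operand. A scalar reference of the store pattern:

```cpp
#include <cstdint>

// Scalar reference of the interleave_save() store pattern above: element i of
// vec0 is paired with element i of vec1 in the output.
inline void interleave_save_ref(const uint16_t* vec0, const uint16_t* vec1,
                                uint16_t* out /* 32 elements */) {
  for (int i = 0; i < 16; ++i) {
    out[2 * i] = vec0[i];      // value from K position k
    out[2 * i + 1] = vec1[i];  // value from K position k + 1
  }
}
```
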
diff --git a/csrc/cpu/cpu_wna16.cpp b/csrc/cpu/cpu_wna16.cpp
new file mode 100644
index 000000000000..816d195506e5
--- /dev/null
+++ b/csrc/cpu/cpu_wna16.cpp
@@ -0,0 +1,402 @@
+#include "cpu_types.hpp"
+#include "scratchpad_manager.h"
+#include "utils.hpp"
+
+#ifdef CPU_CAPABILITY_AMXBF16
+ #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
+#endif
+#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
+
+#define VLLM_DISPATCH_CASE_16B_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
+
+#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__))
+
+template <typename T>
+void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
+ int32_t stride) {
+ std::stringstream ss;
+ ss << std::fixed << std::setprecision(5) << name << ": [\n";
+ auto* curr_logits_buffer = ptr;
+ for (int32_t m = 0; m < row; ++m) {
+ for (int32_t n = 0; n < col; ++n) {
+ ss << curr_logits_buffer[n] << ", ";
+ }
+ ss << "\n";
+ curr_logits_buffer += stride;
+ }
+ ss << "]\n";
+ std::printf("%s", ss.str().c_str());
+}
+
+namespace {
+using cpu_utils::ISA;
+using cpu_utils::VecTypeTrait;
+
+template
+class Dequantizer4b {
+ public:
+ constexpr static int32_t pack_num = 32 / 4;
+ using scalar_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
+
+ public:
+ static void dequant(int32_t* __restrict__ q_weight,
+ scalar_t* __restrict__ weight,
+ scalar_t* __restrict__ scales,
+ int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
+ const int64_t scales_stride, const int64_t zeros_stride,
+ const int32_t k_size, const int32_t group_size) {
+ vec_op::FP32Vec16 lut;
+ if constexpr (has_zp) {
+ // AWQ
+ alignas(64) static const float LUT[16] = {
+ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
+ 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
+ lut = vec_op::FP32Vec16(LUT);
+ } else {
+ // GPTQ
+ alignas(64) static const float LUT[16] = {
+ -8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f,
+ 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+ lut = vec_op::FP32Vec16(LUT);
+ }
+
+ // each 64-bit element packs 4-bit values for 16 output channels
+ int64_t* __restrict__ curr_q_weight = reinterpret_cast<int64_t*>(q_weight);
+ int64_t* __restrict__ curr_zeros = reinterpret_cast<int64_t*>(zeros);
+ scalar_t* __restrict__ curr_weight = weight;
+ scalar_t* __restrict__ curr_scale = scales;
+ vec_op::FP32Vec16 scale_0;
+ vec_op::FP32Vec16 scale_1;
+ vec_op::FP32Vec16 zero_0;
+ vec_op::FP32Vec16 zero_1;
+ int32_t group_counter = 0;
+ for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) {
+ int64_t qwb_0 = *curr_q_weight;
+ int64_t qwb_1 = *(curr_q_weight + 1);
+ vec_op::FP32Vec16 wb_0(qwb_0, lut);
+ vec_op::FP32Vec16 wb_1(qwb_1, lut);
+
+ if constexpr (!use_desc_act) {
+ if (group_counter == 0) {
+ scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale));
+ scale_1 = vec_op::FP32Vec16(scale_0);
+ curr_scale += scales_stride;
+
+ if constexpr (has_zp) {
+ zero_0 = vec_op::FP32Vec16(*curr_zeros, lut);
+ zero_1 = vec_op::FP32Vec16(zero_0);
+ curr_zeros += zeros_stride / 2;
+ }
+ }
+ } else {
+ int32_t g_idx_0 = g_idx[k_idx];
+ int32_t g_idx_1 = g_idx[k_idx + 1];
+ scale_0 = vec_op::FP32Vec16(
+ scalar_vec_t(curr_scale + g_idx_0 * scales_stride));
+ scale_1 = vec_op::FP32Vec16(
+ scalar_vec_t(curr_scale + g_idx_1 * scales_stride));
+ if constexpr (has_zp) {
+ zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2),
+ lut);
+ zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2),
+ lut);
+ }
+ }
+
+ if constexpr (has_zp) {
+ wb_0 = wb_0 - zero_0;
+ wb_1 = wb_1 - zero_1;
+ }
+
+ wb_0 = wb_0 * scale_0;
+ wb_1 = wb_1 * scale_1;
+
+ scalar_vec_t output_vec_0(wb_0);
+ scalar_vec_t output_vec_1(wb_1);
+
+ // AMX needs to interleave K elements to pack them as 32 bits
+ if constexpr (isa == ISA::AMX) {
+ vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight);
+ } else {
+ output_vec_0.save(curr_weight);
+ output_vec_1.save(curr_weight + 16);
+ }
+
+ // update
+ curr_q_weight += 2;
+ curr_weight += 32;
+ if constexpr (!use_desc_act) {
+ group_counter += 2;
+ if (group_counter == group_size) {
+ group_counter = 0;
+ }
+ }
+ }
+ }
+};
+}; // namespace
+
+template
+void cpu_gemm_wna16_impl(
+ scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
+ scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
+ int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
+ scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
+ const int32_t k_size, const int64_t input_stride,
+ const int64_t output_stride, const int64_t scales_group_stride,
+ const int64_t zeros_group_stride, const int32_t group_num,
+ const int32_t group_size, const int64_t pack_factor) {
+ constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
+ constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
+ constexpr int32_t n_block_size = 16;
+ static_assert(gemm_n_tile_size % n_block_size == 0);
+ const int32_t thread_num = omp_get_max_threads();
+
+ // A simple scheduling policy: keep more B tiles resident in L2 while making
+ // sure every thread gets at least one task
+ const int32_t n_partition_size = [&]() {
+ const int64_t cache_size = cpu_utils::get_l2_size();
+ int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
+ int64_t ps_thread_limit = n_size / thread_num;
+ ps_cache_limit =
+ std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
+ (int64_t)gemm_n_tile_size);
+ ps_thread_limit =
+ std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
+ (int64_t)gemm_n_tile_size);
+ return std::min(ps_cache_limit, ps_thread_limit);
+ }();
+ const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;
+
+ // get buffer size
+ const int64_t b_buffer_size =
+ (((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
+ const int64_t c_buffer_size =
+ (((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
+ const int64_t b_buffer_offset = 0;
+ const int64_t c_buffer_offset = b_buffer_size;
+ const int64_t buffer_size = b_buffer_size + c_buffer_size;
+ DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
+ thread_num);
+
+ alignas(64) cpu_utils::Counter counter;
+ cpu_utils::Counter* counter_ptr = &counter;
+
+#pragma omp parallel for schedule(static, 1)
+ for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
+ scalar_t* __restrict__ b_buffer = nullptr;
+ float* __restrict__ c_buffer = nullptr;
+ {
+ uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
+ ->get_data() +
+ thread_id * buffer_size;
+ b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
+ c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
+ }
+
+ const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
+ const int64_t b_buffer_block_stride = n_block_size * k_size;
+ const int32_t zeros_block_stride = n_block_size / pack_factor;
+
+ gemm_t gemm;
+
+ for (;;) {
+ int32_t task_id = counter_ptr->acquire_counter();
+
+ if (task_id >= task_num) {
+ break;
+ }
+
+ const int32_t n_start_idx = task_id * n_partition_size;
+ const int32_t n_block_start_idx = n_start_idx / n_block_size;
+ const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
+ const int32_t n_block_num = n_num / n_block_size;
+ // std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
+ // thread_id, task_id, n_start_idx, n_num);
+
+ // dequant weight
+ {
+ int32_t* __restrict__ curr_q_weight =
+ q_weight + n_block_start_idx * q_weight_block_stride;
+ scalar_t* __restrict__ curr_b_buffer = b_buffer;
+ scalar_t* __restrict__ curr_scales = scales + n_start_idx;
+ int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
+ for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
+ dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
+ curr_zeros, g_idx, scales_group_stride,
+ zeros_group_stride, k_size, group_size);
+
+ // if (block_idx == 0 && n_start_idx == 0) {
+ // print_logits("depacked weight", curr_b_buffer, k_size,
+ // n_block_size, n_block_size);
+ // }
+
+ // update
+ curr_q_weight += q_weight_block_stride;
+ curr_b_buffer += b_buffer_block_stride;
+ curr_scales += n_block_size;
+ curr_zeros += zeros_block_stride;
+ }
+ }
+
+ // compute loop
+ {
+ const int32_t n_tile_num = n_num / gemm_n_tile_size;
+ scalar_t* __restrict__ curr_input = input;
+ scalar_t* __restrict__ init_bias = bias;
+ if (bias != nullptr) {
+ init_bias += n_start_idx;
+ }
+ scalar_t* __restrict__ init_output = output + n_start_idx;
+ for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
+ const int32_t curr_m_size =
+ std::min(gemm_m_tile_size, m_size - m_idx);
+ scalar_t* __restrict__ curr_b_buffer = b_buffer;
+ scalar_t* __restrict__ curr_bias = init_bias;
+ scalar_t* __restrict__ curr_output = init_output;
+ for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
+ gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
+ input_stride, b_buffer_block_stride, gemm_n_tile_size,
+ false);
+
+ if (bias != nullptr) {
+ cpu_micro_gemm::bias_epilogue(
+ c_buffer, curr_output, curr_bias, curr_m_size,
+ gemm_n_tile_size, output_stride);
+ curr_bias += gemm_n_tile_size;
+ } else {
+ cpu_micro_gemm::default_epilogue(
+ c_buffer, curr_output, curr_m_size, gemm_n_tile_size,
+ output_stride);
+ }
+
+ curr_b_buffer +=
+ b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
+ curr_output += gemm_n_tile_size;
+ }
+ curr_input += gemm_m_tile_size * input_stride;
+ init_output += gemm_m_tile_size * output_stride;
+ }
+ }
+ }
+ }
+}
+
+void cpu_gemm_wna16(
+ const torch::Tensor& input, // [M, K]
+ const torch::Tensor&
+ q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
+ torch::Tensor& output, // [M, N]
+ const torch::Tensor& scales, // [group_num, N]
+ const std::optional&
+ zeros, // [group_num, N / pack_factor], packed as int32
+ const std::optional& g_idx, // [K]
+ const std::optional& bias, // [N]
+ const int64_t pack_factor, const std::string& isa_hint) {
+ using cpu_utils::ISA;
+ TORCH_CHECK_EQ(pack_factor, 8); // only 4-bit weights are supported
+ const int32_t a_m_size = input.size(0);
+ const int32_t a_k_size = input.size(1);
+ const int64_t a_m_stride = input.stride(0);
+ const int32_t b_n_size = q_weight.size(0) * 16;
+ TORCH_CHECK_EQ(a_k_size % 32, 0);
+ TORCH_CHECK_EQ(b_n_size % 32, 0);
+ const int32_t group_num = scales.size(0);
+ const int32_t group_size = a_k_size / group_num;
+ TORCH_CHECK_EQ(group_size % 2, 0);
+ const int64_t scales_group_stride = scales.stride(0);
+ const int64_t output_m_stride = output.stride(0);
+
+ bool has_zp = zeros.has_value();
+ bool use_desc_act = g_idx.has_value();
+ TORCH_CHECK(!(has_zp && use_desc_act));
+
+ ISA isa = [&]() {
+ if (isa_hint == "amx") {
+ return ISA::AMX;
+ } else if (isa_hint == "vec") {
+ return ISA::VEC;
+ } else {
+ TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
+ }
+ }();
+
+ int32_t* zeros_ptr = has_zp ? zeros->data_ptr<int32_t>() : nullptr;
+ const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
+ int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr<int32_t>() : nullptr;
+
+ VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
+ if (isa == ISA::AMX) {
+ using gemm_t = cpu_micro_gemm::MicroGemm;
+ if (has_zp) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ if (use_desc_act) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ } else {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ } else if (isa == ISA::VEC) {
+ using gemm_t = cpu_micro_gemm::MicroGemm;
+ if (has_zp) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ if (use_desc_act) {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ } else {
+ using dequantizer_t = Dequantizer4b;
+ cpu_gemm_wna16_impl(
+ input.data_ptr(), q_weight.data_ptr(),
+ output.data_ptr(), scales.data_ptr(), zeros_ptr,
+ g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr,
+ a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
+ scales_group_stride, zeros_group_stride, group_num, group_size,
+ pack_factor);
+ return;
+ }
+ }
+ });
+}
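
As a correctness reference for `Dequantizer4b`, the per-group math for one output channel reduces to the following: GPTQ codes are re-centered by the LUT to the range [-8, 7], while AWQ keeps codes in [0, 15] and subtracts a per-group zero point, before scaling. Names and container types here are illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference of the per-group 4-bit dequantization done by
// Dequantizer4b, for a single output channel. `q` holds codes already
// unpacked to 0..15.
inline std::vector<float> dequant_channel(const std::vector<uint8_t>& q,
                                          const std::vector<float>& scales,
                                          const std::vector<uint8_t>& zeros,
                                          int group_size, bool has_zp) {
  std::vector<float> w(q.size());
  for (std::size_t k = 0; k < q.size(); ++k) {
    const std::size_t g = k / group_size;
    const float centered =
        has_zp ? static_cast<float>(q[k]) - static_cast<float>(zeros[g])  // AWQ
               : static_cast<float>(q[k]) - 8.0f;                         // GPTQ
    w[k] = centered * scales[g];
  }
  return w;
}
```
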
diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
index 02a8072ccf30..cfb6e78cba9a 100644
--- a/csrc/cpu/dnnl_helper.cpp
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -396,9 +396,9 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast(args), args.ab_type),
m_size_cache_(nullptr) {
- assert(ab_type_ == dnnl::memory::data_type::f32 ||
- ab_type_ == dnnl::memory::data_type::bf16 ||
- ab_type_ == dnnl::memory::data_type::f16);
+ assert(b_type_ == dnnl::memory::data_type::f32 ||
+ b_type_ == dnnl::memory::data_type::bf16 ||
+ b_type_ == dnnl::memory::data_type::f16);
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
new file mode 100644
index 000000000000..87a019773a89
--- /dev/null
+++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
@@ -0,0 +1,245 @@
+#ifndef CPU_MICRO_GEMM_AMX_HPP
+#define CPU_MICRO_GEMM_AMX_HPP
+#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
+
+namespace cpu_micro_gemm {
+namespace {
+// AMX specific
+constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
+constexpr static int64_t AMX_TILE_ROW_NUM = 16;
+constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
+
+typedef struct __tile_config {
+ uint8_t palette_id = 1;
+ uint8_t start_row = 0;
+ uint8_t reserved_0[14] = {0};
+ uint16_t colsb[16] = {0};
+ uint8_t rows[16] = {0};
+} __tilecfg;
+
+// 2-2-4 pattern, for 16 < m <= 32
+// TILE 0, 1: load A matrix, row num should be 16, m - 16
+// TILE 2, 3: load B matrix, row num should be 16
+// TILE 4, 5, 6, 7: store the C result matrix, row num should be 16, 16,
+// m - 16, m - 16
+template
+class TileGemm224 {
+ public:
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm224");
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm224");
+ }
+};
+
+template <>
+class TileGemm224 {
+ public:
+ using scalar_t = c10::BFloat16;
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
+ c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
+ c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM;
+ const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
+
+ // B is always packed as 16 output channels block
+ c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
+ c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
+ const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
+
+ float* __restrict__ c_tile_4 = c_ptr;
+ float* __restrict__ c_tile_5 =
+ c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
+ float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc;
+ float* __restrict__ c_tile_7 =
+ c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
+ const int32_t c_tile_stride = ldc * sizeof(float);
+
+ if (accum_c) {
+ _tile_loadd(4, c_tile_4, c_tile_stride);
+ _tile_loadd(5, c_tile_5, c_tile_stride);
+ _tile_loadd(6, c_tile_6, c_tile_stride);
+ _tile_loadd(7, c_tile_7, c_tile_stride);
+ } else {
+ _tile_zero(4);
+ _tile_zero(5);
+ _tile_zero(6);
+ _tile_zero(7);
+ }
+
+ for (int32_t k = 0; k < k_times; ++k) {
+ _tile_loadd(0, a_tile_0, a_tile_stride);
+ _tile_stream_loadd(2, b_tile_2, b_tile_stride);
+ _tile_dpbf16ps(4, 0, 2);
+ _tile_stream_loadd(3, b_tile_3, b_tile_stride);
+ _tile_dpbf16ps(5, 0, 3);
+ _tile_loadd(1, a_tile_1, a_tile_stride);
+ _tile_dpbf16ps(6, 1, 2);
+ _tile_dpbf16ps(7, 1, 3);
+
+ // update ptrs
+ a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ }
+
+ _tile_stored(4, c_tile_4, c_tile_stride);
+ _tile_stored(5, c_tile_5, c_tile_stride);
+ _tile_stored(6, c_tile_6, c_tile_stride);
+ _tile_stored(7, c_tile_7, c_tile_stride);
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ const int32_t m_0 = AMX_TILE_ROW_NUM;
+ const int32_t m_1 = m - AMX_TILE_ROW_NUM;
+ config.rows[0] = m_0;
+ config.rows[1] = m_1;
+ config.rows[2] = AMX_TILE_ROW_NUM;
+ config.rows[3] = AMX_TILE_ROW_NUM;
+ config.rows[4] = m_0;
+ config.rows[5] = m_0;
+ config.rows[6] = m_1;
+ config.rows[7] = m_1;
+ _tile_loadconfig(&config);
+ }
+};
+
+// 1-2-2 pattern, for 0 < m <= 16
+// TILE 0, (1): load A matrix, one extra tile for prefetch, row num should be m, m
+// TILE 2, 3, (4, 5): load B matrix, two extra tiles for prefetch, row num
+// should be 16
+// TILE 6, 7, (6, 7): store the C result matrix, row num should be m
+template
+class TileGemm122 {
+ public:
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm122");
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ TORCH_CHECK(false, "Unsupported data type for TileGemm122");
+ }
+};
+
+template <>
+class TileGemm122 {
+ public:
+ using scalar_t = c10::BFloat16;
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
+ c10::BFloat16* __restrict__ a_tile_1 =
+ a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
+
+ c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
+ c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
+ c10::BFloat16* __restrict__ b_tile_4 =
+ b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ c10::BFloat16* __restrict__ b_tile_5 =
+ b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ int64_t b_stride = AMX_TILE_ROW_BYTES;
+
+ float* __restrict__ c_tile_6 = c_ptr;
+ float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float);
+ int64_t c_stride = ldc * sizeof(float);
+
+ const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
+ const int32_t k_group_times = k_times / 2;
+ const bool has_tail = (k_times % 2 == 1);
+
+ if (accum_c) {
+ _tile_loadd(6, c_tile_6, c_stride);
+ _tile_loadd(7, c_tile_7, c_stride);
+ } else {
+ _tile_zero(6);
+ _tile_zero(7);
+ }
+
+ for (int32_t k = 0; k < k_group_times; ++k) {
+ _tile_loadd(0, a_tile_0, a_tile_stride);
+ _tile_stream_loadd(2, b_tile_2, b_stride);
+ _tile_dpbf16ps(6, 0, 2);
+ _tile_stream_loadd(3, b_tile_3, b_stride);
+ _tile_dpbf16ps(7, 0, 3);
+ _tile_loadd(1, a_tile_1, a_tile_stride);
+ _tile_stream_loadd(4, b_tile_4, b_stride);
+ _tile_dpbf16ps(6, 1, 4);
+ _tile_stream_loadd(5, b_tile_5, b_stride);
+ _tile_dpbf16ps(7, 1, 5);
+
+ // update ptrs
+ a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
+ b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
+ }
+
+ if (has_tail) {
+ _tile_loadd(0, a_tile_0, a_tile_stride);
+ _tile_stream_loadd(2, b_tile_2, b_stride);
+ _tile_dpbf16ps(6, 0, 2);
+ _tile_stream_loadd(3, b_tile_3, b_stride);
+ _tile_dpbf16ps(7, 0, 3);
+ }
+
+ _tile_stored(6, c_tile_6, c_stride);
+ _tile_stored(7, c_tile_7, c_stride);
+ }
+
+ FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
+ config.rows[0] = m;
+ config.rows[1] = m;
+ config.rows[2] = AMX_TILE_ROW_NUM;
+ config.rows[3] = AMX_TILE_ROW_NUM;
+ config.rows[4] = AMX_TILE_ROW_NUM;
+ config.rows[5] = AMX_TILE_ROW_NUM;
+ config.rows[6] = m;
+ config.rows[7] = m;
+ _tile_loadconfig(&config);
+ }
+};
+} // namespace
+
+// Gemm kernel uses AMX, requires B matrix to be packed
+template
+class MicroGemm {
+ public:
+ static constexpr int32_t MaxMSize = 32;
+ static constexpr int32_t NSize = 32;
+
+ public:
+ MicroGemm() : curr_m_(-1) {
+ vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; });
+ }
+
+ void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ if (m > AMX_TILE_ROW_NUM) {
+ if (m != curr_m_) {
+ curr_m_ = m;
+ TileGemm224::init_tile_config(m, amx_tile_config_);
+ }
+ TileGemm224::gemm(CPU_MICRO_GEMM_PARAMS);
+ } else {
+ if (m != curr_m_) {
+ curr_m_ = m;
+ TileGemm122::init_tile_config(m, amx_tile_config_);
+ }
+ TileGemm122::gemm(CPU_MICRO_GEMM_PARAMS);
+ }
+ }
+
+ private:
+ alignas(64) __tilecfg amx_tile_config_;
+ int32_t curr_m_;
+};
+
+} // namespace cpu_micro_gemm
+
+#endif
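
A naive reference GEMM is handy for validating the AMX tiles above against the mathematical definition; the sketch below works on plain row-major [M, K] x [K, N] operands, so a test harness would first unpack the AMX-packed B layout before comparing results:

```cpp
#include <cstdint>

// Naive reference GEMM: C[m][n] (+)= sum_k A[m][k] * B[k][n] on row-major
// operands with strides lda/ldb/ldc.
template <typename scalar_t>
void reference_gemm(const scalar_t* a, const scalar_t* b, float* c,
                    int32_t m, int32_t n, int32_t k,
                    int64_t lda, int64_t ldb, int64_t ldc, bool accum_c) {
  for (int32_t i = 0; i < m; ++i) {
    for (int32_t j = 0; j < n; ++j) {
      float acc = accum_c ? c[i * ldc + j] : 0.0f;
      for (int32_t kk = 0; kk < k; ++kk) {
        acc += static_cast<float>(a[i * lda + kk]) *
               static_cast<float>(b[kk * ldb + j]);
      }
      c[i * ldc + j] = acc;
    }
  }
}
```
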
diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
new file mode 100644
index 000000000000..784da55a420e
--- /dev/null
+++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
@@ -0,0 +1,91 @@
+#ifndef CPU_MICRO_GEMM_IMPL_HPP
+#define CPU_MICRO_GEMM_IMPL_HPP
+#include "cpu/utils.hpp"
+#include "cpu/cpu_types.hpp"
+
+namespace cpu_micro_gemm {
+#define DEFINE_CPU_MICRO_GEMM_PARAMS \
+ scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \
+ float *__restrict__ c_ptr, const int32_t m, const int32_t k, \
+ const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \
+ const bool accum_c
+
+#define CPU_MICRO_GEMM_PARAMS \
+ a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
+
+template
+class MicroGemm {
+ public:
+ static constexpr int32_t MaxMSize = 16;
+ static constexpr int32_t NSize = 16;
+
+ public:
+ void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TORCH_CHECK(false, "Unimplemented MicroGemm.");
+ }
+};
+
+template
+FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr,
+ scalar_t* __restrict__ d_ptr,
+ const int32_t m, const int64_t ldc,
+ const int64_t ldd) {
+ using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t;
+ static_assert(n_size % 16 == 0);
+
+ float* __restrict__ curr_c = c_ptr;
+ scalar_t* __restrict__ curr_d = d_ptr;
+ for (int32_t i = 0; i < m; ++i) {
+ float* __restrict__ curr_c_iter = curr_c;
+ scalar_t* __restrict__ curr_d_iter = curr_d;
+ vec_op::unroll_loop([&](int32_t n_g_idx) {
+ vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
+ scalar_vec_t c_vec(c_vec_fp32);
+ c_vec.save(curr_d_iter);
+ curr_c_iter += 16;
+ curr_d_iter += 16;
+ });
+ curr_c += ldc;
+ curr_d += ldd;
+ }
+}
+
+template
+FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
+ scalar_t* __restrict__ d_ptr,
+ scalar_t* __restrict__ bias_ptr,
+ const int32_t m, const int64_t ldc,
+ const int64_t ldd) {
+ using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t;
+ static_assert(n_size % 16 == 0);
+ constexpr int32_t n_group_num = n_size / 16;
+ static_assert(n_group_num <= 16);
+
+ vec_op::FP32Vec16 bias_vecs[n_group_num];
+ scalar_t* __restrict__ curr_bias = bias_ptr;
+ vec_op::unroll_loop([&](int32_t i) {
+ scalar_vec_t vec(curr_bias);
+ bias_vecs[i] = vec_op::FP32Vec16(vec);
+ curr_bias += 16;
+ });
+
+ float* __restrict__ curr_c = c_ptr;
+ scalar_t* __restrict__ curr_d = d_ptr;
+ for (int32_t i = 0; i < m; ++i) {
+ float* __restrict__ curr_c_iter = curr_c;
+ scalar_t* __restrict__ curr_d_iter = curr_d;
+ vec_op::unroll_loop([&](int32_t n_g_idx) {
+ vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
+ c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
+ scalar_vec_t c_vec(c_vec_fp32);
+ c_vec.save(curr_d_iter);
+ curr_c_iter += 16;
+ curr_d_iter += 16;
+ });
+ curr_c += ldc;
+ curr_d += ldd;
+ }
+}
+} // namespace cpu_micro_gemm
+
+#endif
diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
new file mode 100644
index 000000000000..3985c2f2e5fe
--- /dev/null
+++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
@@ -0,0 +1,115 @@
+#ifndef CPU_MICRO_GEMM_VEC_HPP
+#define CPU_MICRO_GEMM_VEC_HPP
+#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
+
+namespace cpu_micro_gemm {
+namespace {
+// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [K, 32]
+template
+class TileGemm82 {
+ public:
+ FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ switch (m) {
+ case 1:
+ gemm_micro<1>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 2:
+ gemm_micro<2>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 3:
+ gemm_micro<3>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 4:
+ gemm_micro<4>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 5:
+ gemm_micro<5>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 6:
+ gemm_micro<6>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 7:
+ gemm_micro<7>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ case 8:
+ gemm_micro<8>(CPU_MICRO_GEMM_PARAMS);
+ break;
+ }
+ }
+
+ template
+ static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ static_assert(0 < M && M <= 8);
+ using load_vec_t = typename cpu_utils::VecTypeTrait::vec_t;
+
+ scalar_t* __restrict__ curr_b_0 = b_ptr;
+ scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride;
+ float* __restrict__ curr_c_0 = c_ptr;
+ float* __restrict__ curr_c_1 = c_ptr + 16;
+
+ vec_op::FP32Vec16 c_regs[M * 2];
+ if (accum_c) {
+ float* __restrict__ curr_m_c_0 = curr_c_0;
+ float* __restrict__ curr_m_c_1 = curr_c_1;
+ vec_op::unroll_loop([&](int32_t i) {
+ c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
+ c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
+
+ // update
+ curr_m_c_0 += ldc;
+ curr_m_c_1 += ldc;
+ });
+ }
+
+ scalar_t* __restrict__ curr_a = a_ptr;
+ for (int32_t k_idx = 0; k_idx < k; ++k_idx) {
+ load_vec_t b_0_reg(curr_b_0);
+ vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
+ load_vec_t b_1_reg(curr_b_1);
+ vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
+
+ scalar_t* __restrict__ curr_m_a = curr_a;
+ vec_op::unroll_loop([&](int32_t i) {
+ scalar_t v = *curr_m_a;
+ load_vec_t a_reg_original(v);
+ vec_op::FP32Vec16 a_reg(a_reg_original);
+ c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
+ c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
+
+ // update
+ curr_m_a += lda;
+ });
+
+ // update
+ curr_a += 1;
+ curr_b_0 += 16;
+ curr_b_1 += 16;
+ }
+
+ vec_op::unroll_loop([&](int32_t i) {
+ c_regs[i * 2].save(curr_c_0);
+ c_regs[i * 2 + 1].save(curr_c_1);
+
+ // update
+ curr_c_0 += ldc;
+ curr_c_1 += ldc;
+ });
+ }
+};
+} // namespace
+
+// Gemm kernel uses vector instructions, requires B matrix to be packed
+template
+class MicroGemm {
+ public:
+ static constexpr int32_t MaxMSize = 8;
+ static constexpr int32_t NSize = 32;
+
+ public:
+ void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
+ TileGemm82::gemm(CPU_MICRO_GEMM_PARAMS);
+ }
+};
+} // namespace cpu_micro_gemm
+
+#endif
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 5e2aa7069256..e0e3ef71b485 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -100,6 +100,16 @@ void cpu_attention_with_kv_cache(
const torch::Tensor& scheduler_metadata,
const std::optional& s_aux);
+// Note: placeholder to avoid import errors
+void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
+
+void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
+ torch::Tensor& output, const torch::Tensor& scales,
+ const std::optional& zeros,
+ const std::optional