From c3af2af0b8be65ecd1a8538bcfb9622e873e6b3c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 07:25:57 -0700 Subject: [PATCH 001/183] Split PR. Second part. Compile ranges Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 86 ++++++++++++++ vllm/compilation/backends.py | 104 +++++++--------- vllm/compilation/collective_fusion.py | 144 +++++++++-------------- vllm/compilation/compiler_interface.py | 40 ++++--- vllm/compilation/inductor_pass.py | 11 +- vllm/compilation/pass_manager.py | 4 +- vllm/compilation/piecewise_backend.py | 57 +++++---- vllm/compilation/sequence_parallelism.py | 6 +- vllm/config/compilation.py | 33 ++++++ 9 files changed, 288 insertions(+), 197 deletions(-) create mode 100644 tests/compile/test_compile_ranges.py diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..6759da199f4b --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@support_torch_compile +class TestModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE).cuda()) + + +def test_compile_ranges(): + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + )) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 53fd5e74dc0a..686c415f7ac3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,7 +80,8 @@ class CompilerManager: """ def 
__init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[tuple[int, int] | None, int, str], + Any] = (dict()) self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -89,11 +90,11 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: tuple[int, int] | None = None): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: inductor_partition_ops = resolve_defined_ops( self.compilation_config.splitting_ops @@ -150,29 +151,25 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape - ) - if runtime_shape is None: + handle = self.cache[(compile_range, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, compile_range) + if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, - ) + "Directly load the %s-th graph for compile range %s from %s via " + "handle %s", graph_index, str(compile_range), + self.compiler.name, handle) return compiled_graph def compile( @@ -183,7 +180,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, + compile_range: tuple[int, int] | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -195,15 +192,15 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, + compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. # there can be multiple graphs due to piecewise compilation. 
now = time.time() elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " "from the cache, took %.3f s", @@ -211,11 +208,9 @@ def compile( ) else: logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", - str(runtime_shape), - elapsed, - ) + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", str(compile_range), + elapsed) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -224,48 +219,40 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): - compiled_graph, handle = self.compiler.compile( - graph, - example_inputs, - additional_inductor_config, - runtime_shape, - maybe_key, - ) + maybe_key = \ + f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, compile_range, + maybe_key) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, + self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) + "Cache the graph for dynamic shape for later use") else: - logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), - scope="local", - ) - if runtime_shape is None: + logger.info_once("Cache the graph of compile range %s for later use", + str(compile_range)) + if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", + "Store the %s-th graph for dynamic compile range from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", + "Store the %s-th graph for compile range %s from %s via handle %s", graph_index, - str(runtime_shape), + str(compile_range), self.compiler.name, handle, ) @@ -275,19 +262,16 @@ def compile( now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", + elapsed, scope="local", ) else: - logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, - elapsed, - scope="local", - ) + logger.info_once("Compiling a graph for compile range %s takes %.2f s", + str(compile_range), elapsed, scope="local") return compiled_graph @@ -408,7 +392,6 @@ def call_module( i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_dynamic_shape = ( self.vllm_backend.compiler_manager.compile( submod, @@ -417,9 +400,8 @@ def call_module( self.compilation_config, graph_index=index, num_graphs=len(self.compile_submod_names), - 
runtime_shape=None, - ) - ) + compile_range=None, + )) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index cf89182357f2..a4758c971611 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -504,93 +504,59 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - - if num_tokens <= max_token_num: - device_capability = ( - current_platform.get_device_capability().as_version_str() - ) - # Get one shot input size limit for the current world size - # for the current device capability - max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} - ).get(world_size, None) - # Use one shot if no max size for one shot is specified - use_oneshot = ( - max_one_shot_size_mb is None - or current_tensor_size <= max_one_shot_size_mb * MiB - ) - - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, \ + f"Current tensor size {current_tensor_size} is larger than " \ + f"max token num {max_token_num} * hidden size {hidden_size} * " \ + f"element size {element_size}" + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ + get(device_capability, {}). 
\ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size * MiB + + assert ( + _FI_WORKSPACE_TENSOR + is not None), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None and scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -1212,6 +1178,12 @@ def register_patterns(self): self.disabled = False @VllmInductorPass.time_and_log + def is_applicable_for_range( + self, compile_range: tuple[int, int] | None) -> bool: + if compile_range is None: + return False + return compile_range[1] - 1 <= self.max_token_num + def __call__(self, graph: fx.Graph): if self.disabled: logger.debug("AllReduceFusionPass disabled") diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0a3f0769db94..3861bfed11d5 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,16 +63,17 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means + with a range. 
If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + `compile_range` specifies the range of the inputs, + it could be concrete size, e.g. (4, 4). + Right now we only support one variable range of shapes for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: """ Load the compiled function from the handle. @@ -192,18 +193,21 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): - dynamic_shapes = "from_example_inputs" + if isinstance(compile_range, tuple): + if compile_range[0] == compile_range[1]: + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_graph" else: dynamic_shapes = "from_tracing_context" @@ -230,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -294,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -308,7 +312,7 @@ def compile( current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -493,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -589,9 +593,9 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def set_inductor_config(config, compile_range): + if isinstance(compile_range, tuple): + # for a specific range of batchsizes, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -611,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: tuple[int, int | None] = None, key: str | None = None, ) 
-> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4..1b4430c82b2d 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -28,8 +28,8 @@ class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: tuple[int, int] | None): + self.compile_range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +39,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: tuple[int, int] | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +96,8 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: tuple[int, + int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 3bc35a8f7198..82bca8f1fe1b 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -69,9 +69,9 @@ def __init__(self): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 2931580afbbb..87b0121f43cb 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,7 +7,6 @@ import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig @@ -17,8 +16,8 @@ @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: tuple[int, int] compiled: bool = False runnable: Callable = None # type: ignore @@ -55,7 +54,12 @@ def __init__( self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) + + self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ + 0] < range[1] else x == range[0] self.first_run_finished = False @@ -63,24 +67,27 @@ def __init__( self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + # self.concrete_size_entries: dict[int, RangeEntry] = {} + + # the entries for ranges that we need to either + # TODO: we should merge with concrete_size_entries + self.range_entries: dict[tuple[int, int], RangeEntry] = {} - # to_be_compiled_sizes tracks the remaining sizes to compile, + # 
to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[tuple[int, + int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, runnable=self.compiled_graph_for_general_shape, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() @@ -94,28 +101,32 @@ def __call__(self, *args) -> Any: runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: + range_entry = None + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + break + + if (range_entry is None): # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) - entry = self.concrete_size_entries[runtime_shape] + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, - ) + compile_range=range_entry.compile_range) # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): self.check_for_ending_compilation() - return entry.runnable(*args) + return range_entry.runnable(*args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 31624a8fdcc0..78fd8386f56e 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -482,7 +482,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. 
However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -502,7 +502,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] + == compile_range[1]) and (compile_range[1] % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 72418762773c..374e1c99fea0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -214,6 +214,8 @@ class CompilationConfig: - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -331,6 +333,16 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: Optional[list[int]] = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]), + [split_points[0], split_points[1]), ..., + [split_points[-1], max_num_batched_tokens + 1). + Compile sizes are also used single element ranges: + [compile_sizes[i], compile_sizes[i] + 1). + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -914,3 +926,24 @@ def custom_op_log_check(self): enable_str, op, ) + + def get_compile_ranges(self) -> list[tuple[int, int]]: + """Get the compile ranges for the compilation config.""" + compile_ranges_split_points = self.compile_ranges_split_points + compile_ranges = [] + # max_num_batched_tokens + 1 + max_split_point = max(compile_ranges_split_points) + compile_sizes = set(self.compile_sizes) + split_points = sorted( + compile_sizes.union(set(self.compile_ranges_split_points))) + # filter out split points that are greater + # than max_num_batched_tokens + 1 + split_points = [x for x in split_points if x <= max_split_point] + for i, s in enumerate(split_points): + if i == 0: + compile_ranges.append((1, s)) + else: + compile_ranges.append((split_points[i - 1], s)) + if s in compile_sizes and s != 1: + compile_ranges.append((s, s)) + return sorted(compile_ranges) From 0cbb0656ac01d60fb3286e63550d215e95caed81 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 10:00:52 -0700 Subject: [PATCH 002/183] Remove general shape graph Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 14 +------ vllm/compilation/piecewise_backend.py | 53 +++++++++++++-------------- vllm/config/compilation.py | 2 + 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 686c415f7ac3..45a1a8c2f267 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -391,17 +391,7 @@ def call_module( sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - 
self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - compile_range=None, - )) + # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -411,7 +401,7 @@ def call_module( index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, + # compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 87b0121f43cb..d280b85fc82a 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -63,15 +63,12 @@ def __init__( self.first_run_finished = False - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - self.sym_shape_indices = sym_shape_indices # the entries for different shapes that we need to compile # self.concrete_size_entries: dict[int, RangeEntry] = {} # the entries for ranges that we need to either - # TODO: we should merge with concrete_size_entries self.range_entries: dict[tuple[int, int], RangeEntry] = {} # to_be_compiled_ranges tracks the remaining ranges to compile, @@ -81,10 +78,7 @@ def __init__( # We only keep compilation management inside this class directly. for range in self.compile_ranges: - self.range_entries[range] = RangeEntry( - compile_range=range, - runnable=self.compiled_graph_for_general_shape, - ) + self.range_entries[range] = RangeEntry(compile_range=range, ) def check_for_ending_compilation(self): if (self.is_last_graph and not self.to_be_compiled_ranges): @@ -93,24 +87,8 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - - range_entry = None - for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): - range_entry = self.range_entries[range] - break - - if (range_entry is None): - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, + args) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -126,7 +104,28 @@ def __call__(self, *args) -> Any: compile_range=range_entry.compile_range) # finished compilations for all required shapes - if (self.is_last_graph and not self.to_be_compiled_ranges): - self.check_for_ending_compilation() + self.check_for_ending_compilation() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + + # Role of the general is taken by the last range + range_entry = self.range_entries[self.compile_ranges[-1]] + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + + range_entry = None + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + break + assert range_entry is not None, \ + f"Shape out of considered range: {runtime_shape} " \ + "[1, max_num_batched_tokens]" + + self._maybe_compile_for_range_entry(range_entry, args) return range_entry.runnable(*args) diff --git 
a/vllm/config/compilation.py b/vllm/config/compilation.py index 374e1c99fea0..2aab5cb5f295 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -946,4 +946,6 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges.append((split_points[i - 1], s)) if s in compile_sizes and s != 1: compile_ranges.append((s, s)) + assert compile_ranges[-1][1] == max_split_point, \ + "Last compile range end should be max_split_point" return sorted(compile_ranges) From d5392f54cb6e8f15926f1d89642ad08cda44a99c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 5 Sep 2025 06:00:15 -0700 Subject: [PATCH 003/183] Add test to test pipeline Signed-off-by: ilmarkov --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6cbc25b4b3bf..105eca371ff3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -412,6 +412,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_compile_ranges.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 From 027c9eb348808e1a37c9dbc86fbfcd020e2166a8 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 9 Sep 2025 05:32:05 -0700 Subject: [PATCH 004/183] Fix pre-commit Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index d280b85fc82a..cec8aca63d80 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -117,12 +117,13 @@ def __call__(self, *args) -> Any: runtime_shape = args[self.sym_shape_indices[0]] - range_entry = None + range_found = False for range in self.compile_ranges: if self.is_in_range(runtime_shape, range): range_entry = self.range_entries[range] + range_found = True break - assert range_entry is not None, \ + assert range_found, \ f"Shape out of considered range: {runtime_shape} " \ "[1, max_num_batched_tokens]" From b2992d3b9afa19156df1453fa504df87ecbc30d9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:12:17 +0000 Subject: [PATCH 005/183] Upd Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 48 ++++++++-------- vllm/compilation/backends.py | 73 ++++++++++++++---------- vllm/compilation/collective_fusion.py | 19 +++--- vllm/compilation/compiler_interface.py | 16 +++--- vllm/compilation/inductor_pass.py | 3 +- vllm/compilation/pass_manager.py | 2 +- vllm/compilation/piecewise_backend.py | 30 +++++----- vllm/compilation/sequence_parallelism.py | 8 ++- vllm/config/compilation.py | 8 ++- 9 files changed, 114 insertions(+), 93 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 6759da199f4b..68389ccfbe14 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -6,8 +6,12 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import set_forward_context from vllm.utils import direct_register_custom_op @@ -18,15 +22,17 @@ MLP_SIZE = 128 -def 
silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: out.copy_(q) out += k out += v -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: return @@ -41,12 +47,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @support_torch_compile class TestModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -59,8 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, - batch_sizes: list[int]): +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): with set_forward_context({}, vllm_config=vllm_config): model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) for batch_size in batch_sizes: @@ -68,19 +68,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, def test_compile_ranges(): - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - compile_ranges_split_points=[8, 32], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + ) + ) with set_current_vllm_config(vllm_config): - model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() batch_sizes = [1, 16, 48] # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45a1a8c2f267..beda9b36f686 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -80,8 +80,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[tuple[int, int] | None, int, str], - Any] = (dict()) + self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -156,20 +155,26 @@ def load( if (compile_range, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(compile_range, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load(handle, graph, example_inputs, - graph_index, compile_range) + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, compile_range + ) if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic compile range from %s via handle %s", + "Directly load the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Directly load the %s-th graph for compile range %s from %s via " - 
"handle %s", graph_index, str(compile_range), - self.compiler.name, handle) + "Directly load the %s-th graph for compile range %s" + "from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) return compiled_graph def compile( @@ -192,8 +197,7 @@ def compile( compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, - compile_range) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. @@ -209,8 +213,10 @@ def compile( else: logger.info( "Directly load the compiled graph(s) for compile range %s " - "from the cache, took %.3f s", str(compile_range), - elapsed) + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -219,38 +225,43 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = \ - f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" - compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, compile_range, - maybe_key) + maybe_key = f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + with self.compile_context(compile_range): + compiled_graph, handle = self.compiler.compile( + graph, + example_inputs, + additional_inductor_config, + compile_range, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(compile_range, graph_index, - self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph if compile_range is None: - logger.info_once( - "Cache the graph for dynamic shape for later use") + logger.info_once("Cache the graph for dynamic shape for later use", scope="local") else: - logger.info_once("Cache the graph of compile range %s for later use", - str(compile_range)) + logger.info_once( + "Cache the graph of compile range %s for later use", + str(compile_range), + ) if compile_range is None: logger.debug( - "Store the %s-th graph for dynamic compile range from %s via handle %s", + "Store the %s-th graph for dynamic compile range" + "from %s via handle %s", graph_index, self.compiler.name, handle, ) else: logger.debug( - "Store the %s-th graph for compile range %s from %s via handle %s", + "Store the %s-th graph for compile range%s from %s via handle %s", graph_index, str(compile_range), self.compiler.name, @@ -264,14 +275,17 @@ def compile( compilation_config.compilation_time += elapsed if compile_range is None: logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", - + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) else: - logger.info_once("Compiling a graph for compile range %s takes %.2f s", - str(compile_range), elapsed, scope="local") + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph @@ -401,7 +415,6 @@ def call_module( index, len(self.compile_submod_names), sym_shape_indices, - # compiled_graph_for_dynamic_shape, self.vllm_backend, ) diff --git 
a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a4758c971611..3d970ac2964b 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -505,12 +505,12 @@ def call_trtllm_fused_allreduce_norm( element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size max_tensor_size = max_token_num * hidden_size * element_size - assert current_tensor_size <= max_tensor_size, \ - f"Current tensor size {current_tensor_size} is larger than " \ - f"max token num {max_token_num} * hidden size {hidden_size} * " \ + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " f"element size {element_size}" - device_capability = current_platform.get_device_capability( - ).as_version_str() + ) + device_capability = current_platform.get_device_capability().as_version_str() # Get one shot input size limit for the current world size # for the current device capability max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ @@ -520,9 +520,9 @@ def call_trtllm_fused_allreduce_norm( use_oneshot = max_one_shot_size is None or \ current_tensor_size <= max_one_shot_size * MiB - assert ( - _FI_WORKSPACE_TENSOR - is not None), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -1178,8 +1178,7 @@ def register_patterns(self): self.disabled = False @VllmInductorPass.time_and_log - def is_applicable_for_range( - self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: if compile_range is None: return False return compile_range[1] - 1 <= self.max_token_num diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3861bfed11d5..4e5aa077ddae 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -63,14 +63,14 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `compile_range` specifies the range of the inputs, + `compile_range` specifies the range of the inputs, it could be concrete size, e.g. (4, 4). Right now we only support one variable range of shapes for all inputs, which is the batchsize (number of tokens) during inference. @@ -99,7 +99,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: """ Load the compiled function from the handle. 
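
As a reference for the signature changes in this file, here is a minimal standalone sketch (not part of the patch) of the `compile_range` convention these interfaces now share. `describe_compile_range` is a hypothetical helper; the three cases mirror the `dynamic_shapes` selection in `InductorStandaloneAdaptor.compile` and the half-open bounds check used by `PiecewiseBackend`.

def describe_compile_range(compile_range: tuple[int, int] | None) -> str:
    # None -> fully dynamic batch dimension (one graph for all shapes)
    if compile_range is None:
        return "dynamic shape: one graph compiled for a symbolic num_tokens"
    lo, hi = compile_range
    if lo == hi:
        # mirrors the `from_example_inputs` branch in InductorStandaloneAdaptor.compile
        return f"concrete size {lo}: specialize shapes from the example inputs"
    # half-open range, matching PiecewiseBackend.is_in_range: lo <= num_tokens < hi
    return f"range [{lo}, {hi}): symbolic but bounded num_tokens for this artifact"
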
@@ -193,7 +193,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -234,7 +234,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -298,7 +298,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -497,7 +497,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -615,7 +615,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int | None] = None, + compile_range: tuple[int, int] | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 1b4430c82b2d..599fa776b6c0 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -96,8 +96,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: tuple[int, - int] | None): + def is_applicable_for_range(self, compile_range: tuple[int, int] | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 82bca8f1fe1b..08002dc862f6 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -75,7 +75,7 @@ def __call__(self, graph: fx.Graph): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index cec8aca63d80..607d6a80f5cf 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -30,7 +30,6 @@ def __init__( piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -58,8 +57,11 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) - self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ - 0] < range[1] else x == range[0] + self.is_in_range = ( + lambda x, range: range[0] <= x < range[1] + if range[0] < range[1] + else x == range[0] + ) self.first_run_finished = False @@ -73,22 +75,22 @@ def __init__( # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy 
it - self.to_be_compiled_ranges: set[tuple[int, - int]] = set(self.compile_ranges) + self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. for range in self.compile_ranges: - self.range_entries[range] = RangeEntry(compile_range=range, ) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) def check_for_ending_compilation(self): - if (self.is_last_graph and not self.to_be_compiled_ranges): + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, - args) -> Any: + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -101,7 +103,8 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - compile_range=range_entry.compile_range) + compile_range=range_entry.compile_range, + ) # finished compilations for all required shapes self.check_for_ending_compilation() @@ -123,9 +126,10 @@ def __call__(self, *args) -> Any: range_entry = self.range_entries[range] range_found = True break - assert range_found, \ - f"Shape out of considered range: {runtime_shape} " \ - "[1, max_num_batched_tokens]" + assert range_found, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) self._maybe_compile_for_range_entry(range_entry, args) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 78fd8386f56e..cf47adb4670a 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -502,9 +502,11 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool ): return True tp_size = get_tensor_model_parallel_world_size() - return compile_range is not None and ( - compile_range[0] - == compile_range[1]) and (compile_range[1] % tp_size == 0) + return ( + compile_range is not None + and (compile_range[0] == compile_range[1]) + and (compile_range[1] % tp_size == 0) + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2aab5cb5f295..278fe5801323 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -333,7 +333,7 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" - compile_ranges_split_points: Optional[list[int]] = None + compile_ranges_split_points: list[int] | None = None """Split points that represent compile ranges for inductor. 
The compile ranges are [1, split_points[0]), @@ -935,7 +935,8 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: max_split_point = max(compile_ranges_split_points) compile_sizes = set(self.compile_sizes) split_points = sorted( - compile_sizes.union(set(self.compile_ranges_split_points))) + compile_sizes.union(set(self.compile_ranges_split_points)) + ) # filter out split points that are greater # than max_num_batched_tokens + 1 split_points = [x for x in split_points if x <= max_split_point] @@ -946,6 +947,7 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges.append((split_points[i - 1], s)) if s in compile_sizes and s != 1: compile_ranges.append((s, s)) - assert compile_ranges[-1][1] == max_split_point, \ + assert compile_ranges[-1][1] == max_split_point, ( "Last compile range end should be max_split_point" + ) return sorted(compile_ranges) From 3499384c1e183cd851c93d12ea7d77c08de03ed2 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 20:32:36 +0000 Subject: [PATCH 006/183] Upd config Signed-off-by: ilmarkov --- vllm/config/vllm.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 916f258d6586..fd38992e374b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -426,6 +426,8 @@ def __post_init__(self): "correctness and to realize prefill savings. " ) + self._set_compile_ranges() + disable_chunked_prefill_reasons: list[str] = [] if self.model_config: @@ -796,6 +798,49 @@ def _set_cudagraph_sizes(self): # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. + """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + # We add 1 because the bounds checks in the compiler are exclusive + # and we want to include the max_num_batched_tokens + # in the compile range + computed_compile_ranges_split_points.append(max_num_batched_tokens + 1) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.enable_fi_allreduce_fusion: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + # We add 1 because the bounds checks in the compiler are + # exclusive and we want to include the max_token_num in the + # compile range + computed_compile_ranges_split_points.append(max_token_num + 1) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) # type: ignore + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config From 5336ee6ffe1d5b03b69b23f4b346ba10a549c6cd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 16 Oct 2025 
20:51:01 +0000 Subject: [PATCH 007/183] Fix Signed-off-by: ilmarkov --- vllm/compilation/collective_fusion.py | 18 ++++++++++-------- vllm/v1/worker/utils.py | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 3d970ac2964b..7c0a1208d870 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -431,7 +431,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -441,7 +441,9 @@ def is_applicable(self, shape: int | None) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + ) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -1100,18 +1102,18 @@ def __init__(self, config: VllmConfig): ) return element_size = 4 if use_fp32_lamport else 2 - max_token_num = max_size // (self.hidden_dim * element_size) + self.max_token_num = max_size // (self.hidden_dim * element_size) # take the min to save workspace size and we'll never use more # than max_num_batched_tokens anyways - max_token_num = min( - max_token_num, config.scheduler_config.max_num_batched_tokens + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_token_num, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1124,7 +1126,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_token_num, + max_token_num=self.max_token_num, ) self.register_patterns() @@ -1177,12 +1179,12 @@ def register_patterns(self): self.disabled = False - @VllmInductorPass.time_and_log def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: if compile_range is None: return False return compile_range[1] - 1 <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: logger.debug("AllReduceFusionPass disabled") diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 92baf0cb7136..ef953dd2051e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -330,7 +330,7 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. 
- This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes From 4958474f77a930f532730a9ec7a395339ea32138 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 17 Oct 2025 11:30:21 +0000 Subject: [PATCH 008/183] Priotitize compile_sizes Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 28 ++++++++++++++++++++------- vllm/config/compilation.py | 18 ++--------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 607d6a80f5cf..7a10fed1d237 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -57,6 +57,10 @@ def __init__( log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" logger.debug_once(log_string) + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) + self.is_in_range = ( lambda x, range: range[0] <= x < range[1] if range[0] < range[1] @@ -78,6 +82,12 @@ def __init__( self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. + for size in self.compile_sizes: + range = (size, size) + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + for range in self.compile_ranges: self.range_entries[range] = RangeEntry( compile_range=range, @@ -112,20 +122,24 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True + self.check_for_ending_compilation() - # Role of the general is taken by the last range + # Role of the general graph is taken by the last range graph range_entry = self.range_entries[self.compile_ranges[-1]] self._maybe_compile_for_range_entry(range_entry, args) return range_entry.runnable(*args) - runtime_shape = args[self.sym_shape_indices[0]] range_found = False - for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): - range_entry = self.range_entries[range] - range_found = True - break + if runtime_shape in self.compile_sizes: + range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_found = True + else: + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + range_found = True + break assert range_found, ( f"Shape out of considered range: {runtime_shape} " "[1, max_num_batched_tokens]" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 278fe5801323..c2a6d6d783b9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -929,25 +929,11 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - compile_ranges_split_points = self.compile_ranges_split_points + split_points = self.compile_ranges_split_points compile_ranges = [] - # max_num_batched_tokens + 1 - max_split_point = max(compile_ranges_split_points) - compile_sizes = set(self.compile_sizes) - split_points = sorted( - compile_sizes.union(set(self.compile_ranges_split_points)) - ) - # filter out split points that 
are greater - # than max_num_batched_tokens + 1 - split_points = [x for x in split_points if x <= max_split_point] for i, s in enumerate(split_points): if i == 0: compile_ranges.append((1, s)) else: compile_ranges.append((split_points[i - 1], s)) - if s in compile_sizes and s != 1: - compile_ranges.append((s, s)) - assert compile_ranges[-1][1] == max_split_point, ( - "Last compile range end should be max_split_point" - ) - return sorted(compile_ranges) + return compile_ranges From 04306ed0dacf3fc11bcfb5ae993095d8d5a506bb Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 28 Oct 2025 13:26:59 +0000 Subject: [PATCH 009/183] Fix inductor config Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 7 ++++++- vllm/compilation/compiler_interface.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index beda9b36f686..30ab91e4ab82 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -225,7 +225,12 @@ def compile( # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_compile_range_{compile_range}_subgraph_{graph_index}" + maybe_key = "artifact_compile_range_" + if compile_range is None: + maybe_key += "dynamic_shape" + else: + maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 4e5aa077ddae..d069769fe76f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -594,8 +594,8 @@ def metrics_context(self) -> contextlib.AbstractContextManager: def set_inductor_config(config, compile_range): - if isinstance(compile_range, tuple): - # for a specific range of batchsizes, tuning triton kernel parameters + if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( From 9dc4eea25b0ec2520d920616002a6f148a1c3801 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 3 Nov 2025 10:53:49 +0000 Subject: [PATCH 010/183] Laith's fix Signed-off-by: ilmarkov --- vllm/compilation/compiler_interface.py | 38 +++++++++++++++++++++----- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index d069769fe76f..3453b8f676e8 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,13 +213,37 @@ def compile( from torch._inductor import standalone_compile - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) - + if dynamic_shapes == "from_graph": + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # standalone_compile probably should not accept + # non fake tensors as example inputs! 
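# In PyTorch terms this is "duck sizing": symbolic dims that happen to share
# the same size hint can get folded onto one symbol, so specializing one of
# them later pins the other as well. Passing the placeholders' fake
# example_value (collected below) keeps each dim on the symbol recorded in
# the FX graph instead of re-deriving symbols from real-tensor hints.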
+ fake_example_inputs = [] + for node in graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(example_inputs) + compiled_graph = standalone_compile( + graph, + fake_example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) + else: + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) From 2c63f0b05c02ce4d93e23093b3838af775d92614 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 10:22:17 +0000 Subject: [PATCH 011/183] Upd Signed-off-by: ilmarkov --- vllm/compilation/backends.py | 6 ++++-- vllm/compilation/collective_fusion.py | 11 ++++++----- vllm/config/compilation.py | 3 +++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 30ab91e4ab82..7cda5d0dee96 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -250,7 +250,9 @@ def compile( if graph_index == 0: # adds some info logging for the first graph if compile_range is None: - logger.info_once("Cache the graph for dynamic shape for later use", scope="local") + logger.info_once( + "Cache the graph for dynamic shape for later use", scope="local" + ) else: logger.info_once( "Cache the graph of compile range %s for later use", @@ -280,7 +282,7 @@ def compile( compilation_config.compilation_time += elapsed if compile_range is None: logger.info_once( - "Compiling a graph for dynamic compile range takes %.2f s", + "Compiling a graph for dynamic compile range takes %.2f s", elapsed, scope="local", ) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7c0a1208d870..9c20db07c267 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -515,12 +515,13 @@ def call_trtllm_fused_allreduce_norm( device_capability = current_platform.get_device_capability().as_version_str() # Get one shot input size limit for the current world size # for the current device capability - max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB. \ - get(device_capability, {}). 
\ - get(world_size, None) + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) # Use one shot if no max size is specified - use_oneshot = max_one_shot_size is None or \ - current_tensor_size <= max_one_shot_size * MiB + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB + ) assert _FI_WORKSPACE_TENSOR is not None, ( "Flashinfer must be enabled when using flashinfer" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c2a6d6d783b9..e469c8e25a43 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -142,6 +142,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: max_sizes = { k: int(v * MiB) for k, v in self.fi_allreduce_fusion_max_size_mb.items() } + logger.debug_once( + f"flashinfer_max_size: {max_sizes.get(world_size)}", scope="global" + ) # return None if world size is not supported by flashinfer return max_sizes.get(world_size) From fcebc21fb1708abbfc2622cfeee517aef801c622 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 4 Nov 2025 14:30:18 +0000 Subject: [PATCH 012/183] Add caching Signed-off-by: ilmarkov --- vllm/compilation/compiler_interface.py | 37 +++++--------------------- vllm/compilation/pass_manager.py | 1 + vllm/compilation/piecewise_backend.py | 23 +++++++++++++++- vllm/config/compilation.py | 8 +++--- 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3453b8f676e8..6a57cd4bc578 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,37 +213,12 @@ def compile( from torch._inductor import standalone_compile - if dynamic_shapes == "from_graph": - # We need to pass fake example_inputs, otherwise torch.compile - # will fakify the example_inputs potentially causing some non dynamic - # dimension to be be duck shaped to other existing shapes that have hints - # matching their values. - # This is problem because it can lead to unintended specializations! - # if the new wrongly dynamic dim is specialized - # it will force specializing the whole shape - # standalone_compile probably should not accept - # non fake tensors as example inputs! 
- fake_example_inputs = [] - for node in graph.graph.nodes: - # All place holders come first - if node.op == "placeholder": - fake_example_inputs.append(node.meta["example_value"]) - else: - break - assert len(fake_example_inputs) == len(example_inputs) - compiled_graph = standalone_compile( - graph, - fake_example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) - else: - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 08002dc862f6..3e0c9bc99a24 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -155,5 +155,6 @@ def uuid(self): # See [HACK: Bug with Inductor graph partition and torch.compile cache] state["inductor_splitting_ops"].extend(self.inductor_splitting_ops) + state["compile_range"] = get_pass_context().compile_range return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 7a10fed1d237..ad5b49f28550 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -87,6 +87,7 @@ def __init__( self.range_entries[range] = RangeEntry( compile_range=range, ) + self.to_be_compiled_ranges.add(range) for range in self.compile_ranges: self.range_entries[range] = RangeEntry( @@ -100,6 +101,26 @@ def check_for_ending_compilation(self): self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) + def fakify_args(self, args: list[Any]) -> list[Any]: + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # torch.compile probably should not accept + # non fake tensors as example inputs! 
+ fake_example_inputs = [] + for node in self.graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(args) + return fake_example_inputs + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: if not range_entry.compiled: range_entry.compiled = True @@ -108,7 +129,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # args are real arguments range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - args, + self.fakify_args(args), self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 475f4c15afef..fa728c23d145 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -142,11 +142,9 @@ def flashinfer_max_size(self, world_size: int) -> int | None: max_size_mb = self.fi_allreduce_fusion_max_size_mb if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) - logger.debug_once( - f"flashinfer_max_size: {int(max_size_mb * MiB)}", scope="global" - ) - return int(max_size_mb * MiB) - return None + max_size_bytes = int(max_size_mb * MiB) if max_size_mb is not None else None + logger.debug_once(f"flashinfer_max_size: {max_size_bytes}", scope="global") + return max_size_bytes @staticmethod def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: From 65151bcecf8429890f4fa191e7988aedfb2c9aa5 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 12:58:20 +0000 Subject: [PATCH 013/183] Address comments Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 65 +++++++++++++++------------ vllm/compilation/collective_fusion.py | 5 +++ vllm/config/compilation.py | 1 - 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 68389ccfbe14..03f31df1ece7 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -1,19 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from torch import fx as fx from torch import nn from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.inductor_pass import ( + CustomGraphPass, + InductorPass, + get_pass_context, +) from vllm.config import ( - CompilationConfig, - CompilationLevel, VllmConfig, set_current_vllm_config, ) +from vllm.config.compilation import CompilationConfig, CompilationMode +from vllm.config.scheduler import SchedulerConfig from vllm.forward_context import set_forward_context -from vllm.utils import direct_register_custom_op # create a library to hold the custom op silly_lib = Library("silly", "FRAGMENT") # noqa @@ -22,29 +27,6 @@ MLP_SIZE = 128 -def silly_attention( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor -) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor -) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class 
TestModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: @@ -67,12 +49,37 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE).cuda()) +class PostGradPassManagerCheckRanges(CustomGraphPass): + def __init__(self, ranges: list[tuple[int, int]]): + self.ranges = ranges + + def __call__(self, graph: fx.Graph): + compile_range = get_pass_context().compile_range + assert compile_range in self.ranges, ( + f"Compile range {compile_range} not in {self.ranges}" + ) + + def uuid(self) -> str: + state = { + "ranges": self.ranges, + } + return InductorPass.hash_dict(state) + + def test_compile_ranges(): vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], - ) + ), + inductor_compile_config={ + "post_grad_custom_post_pass": PostGradPassManagerCheckRanges( + [(1, 8), (8, 32), (32, 2049)] + ) + }, ) with set_current_vllm_config(vllm_config): @@ -82,7 +89,7 @@ def test_compile_ranges(): with compilation_counter.expect( num_graphs_seen=1, num_piecewise_graphs_seen=1, - num_backend_compilations=4, + num_backend_compilations=3, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 9c20db07c267..aaf53c6e5768 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1109,6 +1109,11 @@ def __init__(self, config: VllmConfig): self.max_token_num = min( self.max_token_num, config.scheduler_config.max_num_batched_tokens ) + logger.debug_once( + f"Flashinfer max size: {max_size // (1024 * 1024)} MB" + f", Maximal number of tokens: {self.max_token_num}", + scope="global", + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index fa728c23d145..6e50493a770c 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -143,7 +143,6 @@ def flashinfer_max_size(self, world_size: int) -> int | None: if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) max_size_bytes = int(max_size_mb * MiB) if max_size_mb is not None else None - logger.debug_once(f"flashinfer_max_size: {max_size_bytes}", scope="global") return max_size_bytes @staticmethod From df22202272995c4a9c99f1ae7c562416d9620e53 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 11:25:17 -0500 Subject: [PATCH 014/183] Update benchmark Signed-off-by: ilmarkov --- benchmarks/kernels/benchmark_fused_collective.py | 16 ++++++++++++---- vllm/config/compilation.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index cec134ff9138..d7fa0580a3e7 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -410,6 +410,7 @@ def run_benchmarks( use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, quant_modes: set[str], + no_oneshot: bool, ): """Run all benchmarks for given configuration. 
@@ -431,6 +432,7 @@ def run_benchmarks( rms_eps = 1e-6 results = {} vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] # Create RMSNorm and QuantFP8 layers once for native benchmarks @@ -476,7 +478,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -560,7 +562,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -645,7 +647,7 @@ def run_benchmarks( # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot if flashinfer_comm is not None and allreduce_params is not None: - for use_oneshot in [True, False]: + for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" try: time_ms = benchmark_operation( @@ -901,7 +903,7 @@ def save_results_to_file( try: markdown_content = format_results_markdown(all_results, world_size, args) - with open(output_path, "w") as f: + with open(output_path, "a") as f: f.write(markdown_content) except Exception as e: @@ -960,6 +962,12 @@ def main(): """, ) + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + args = parser.parse_args() # Check if running with torchrun (required for collective operations) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 6e50493a770c..6f35673856df 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -923,7 +923,7 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - split_points = self.compile_ranges_split_points + split_points = set(self.compile_ranges_split_points) compile_ranges = [] for i, s in enumerate(split_points): if i == 0: From a21de2baef2202f2610788027c904f9b377752e9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 16:32:59 +0000 Subject: [PATCH 015/183] Fix Signed-off-by: ilmarkov --- benchmarks/kernels/benchmark_fused_collective.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index d7fa0580a3e7..99213d0c7cc2 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -1076,6 +1076,7 @@ def main(): use_residual, allreduce_params, quant_modes=quant_modes, + no_oneshot=args.no_oneshot, ) # Store results for markdown export From 6766e4f7da7914d7b1a24e6d760f56e181d5fbaa Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 5 Nov 2025 17:15:46 -0500 Subject: [PATCH 016/183] Update fakify for compile sizes Signed-off-by: ilmarkov --- vllm/compilation/piecewise_backend.py | 9 ++++++++- vllm/config/compilation.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index ad5b49f28550..fe35aaa9e4ae 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -122,14 +122,21 @@ def fakify_args(self, args: list[Any]) -> 
list[Any]: return fake_example_inputs def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + is_compile_size = lambda range: range[0] == range[1] if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) # args are real arguments + # fakify for range, real args for concrete size + args = ( + self.fakify_args(args) + if not is_compile_size(range_entry.compile_range) + else args + ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - self.fakify_args(args), + args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 740b970669ed..67cd974a13e7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -947,7 +947,7 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" - split_points = set(self.compile_ranges_split_points) + split_points = sorted(set(self.compile_ranges_split_points)) compile_ranges = [] for i, s in enumerate(split_points): if i == 0: From af87d7a7996dc857933ce38b8be3badbed95a935 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 6 Nov 2025 09:59:37 -0500 Subject: [PATCH 017/183] Linter fix Signed-off-by: ilmarkov --- vllm/config/compilation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 67cd974a13e7..3a3fdd7f295d 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -947,6 +947,8 @@ def custom_op_log_check(self): def get_compile_ranges(self) -> list[tuple[int, int]]: """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] split_points = sorted(set(self.compile_ranges_split_points)) compile_ranges = [] for i, s in enumerate(split_points): From b4c1b1d66d6ce3288c65c57251d0492f2e9f475b Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 10 Nov 2025 12:31:48 +0000 Subject: [PATCH 018/183] Address the review Signed-off-by: ilmarkov --- tests/compile/test_compile_ranges.py | 50 +++++++++++++----------- vllm/compilation/backends.py | 12 +++--- vllm/compilation/collective_fusion.py | 9 +++-- vllm/compilation/compiler_interface.py | 21 +++++----- vllm/compilation/inductor_pass.py | 9 +++-- vllm/compilation/pass_manager.py | 4 +- vllm/compilation/piecewise_backend.py | 27 ++++++------- vllm/compilation/sequence_parallelism.py | 7 ++-- vllm/config/compilation.py | 8 ++-- vllm/config/utils.py | 36 ++++++++++++++++- vllm/config/vllm.py | 6 ++- vllm/v1/worker/gpu_worker.py | 19 ++++++++- 12 files changed, 137 insertions(+), 71 deletions(-) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 03f31df1ece7..564690f18192 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -3,12 +3,11 @@ import torch from torch import fx as fx from torch import nn -from torch.library import Library +import tests.compile.silly_attention # noqa from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.inductor_pass import ( - CustomGraphPass, InductorPass, get_pass_context, ) @@ -18,11 +17,9 @@ ) from vllm.config.compilation import CompilationConfig, CompilationMode from vllm.config.scheduler import SchedulerConfig +from vllm.config.utils 
import Range from vllm.forward_context import set_forward_context -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - BATCH_SIZE = 64 MLP_SIZE = 128 @@ -49,24 +46,34 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]) model(torch.randn(batch_size, MLP_SIZE).cuda()) -class PostGradPassManagerCheckRanges(CustomGraphPass): - def __init__(self, ranges: list[tuple[int, int]]): +class PostGradPassManagerCheckRanges(InductorPass): + def __init__(self, ranges: list[Range]): self.ranges = ranges + self.num_calls = 0 def __call__(self, graph: fx.Graph): compile_range = get_pass_context().compile_range assert compile_range in self.ranges, ( f"Compile range {compile_range} not in {self.ranges}" ) + self.num_calls += 1 def uuid(self) -> str: state = { - "ranges": self.ranges, + "ranges": [str(range) for range in self.ranges], + "current_compile_range": str(get_pass_context().compile_range), } return InductorPass.hash_dict(state) def test_compile_ranges(): + post_grad_pass_manager = PostGradPassManagerCheckRanges( + [ + Range(start=1, end=8), + Range(start=8, end=32), + Range(start=32, end=8193), + ] + ) vllm_config = VllmConfig( scheduler_config=SchedulerConfig( max_num_batched_tokens=8192, @@ -74,22 +81,21 @@ def test_compile_ranges(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, compile_ranges_split_points=[8, 32], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_pass_manager + }, ), - inductor_compile_config={ - "post_grad_custom_post_pass": PostGradPassManagerCheckRanges( - [(1, 8), (8, 32), (32, 2049)] - ) - }, ) with set_current_vllm_config(vllm_config): model = TestModel(vllm_config=vllm_config, prefix="").eval().cuda() - batch_sizes = [1, 16, 48] - # A has support_torch_compile - with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=1, - num_backend_compilations=3, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - ): - run_model(vllm_config, model, batch_sizes) + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=3, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_pass_manager.num_calls == 3 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7a1d851ebe42..0d7ef88c8e6a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -22,6 +22,7 @@ resolve_defined_ops, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig +from vllm.config.utils import Range from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -83,7 +84,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[tuple[int, int] | None, int, str], Any] = dict() + self.cache: dict[tuple[Range | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -92,7 +93,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, compile_range: tuple[int, int] | None = None): + def compile_context(self, compile_range: Range | None = None): """Provide compilation context for the duration 
of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" @@ -152,7 +153,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None @@ -187,7 +188,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -206,6 +207,7 @@ def compile( # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " @@ -231,7 +233,7 @@ def compile( if compile_range is None: maybe_key += "dynamic_shape" else: - maybe_key += f"{compile_range[0]}_{compile_range[1]}" + maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index dbe17f984808..81e881373e45 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -431,7 +432,7 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. 
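
A minimal sketch of the range gate these passes now share, assuming the `Range` helper introduced later in this series (start inclusive, end exclusive, "single size" when start == end); illustrative only, not the verbatim pass code:

    from vllm.config.utils import Range

    def sp_style_applicable(compile_range: Range | None, tp_size: int) -> bool:
        # The always-applicable full-graph compilation case is omitted here;
        # otherwise these passes only fire for a concrete batch size that is
        # divisible by the tensor-parallel world size.
        return (
            compile_range is not None
            and compile_range.is_single_size()
            and compile_range.end % tp_size == 0
        )

    assert sp_style_applicable(Range(start=8, end=8), tp_size=4)       # concrete size
    assert not sp_style_applicable(Range(start=8, end=32), tp_size=4)  # genuine range
    assert not sp_style_applicable(None, tp_size=4)                    # dynamic shape
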
@@ -442,7 +443,7 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool return True tp_size = get_tensor_model_parallel_world_size() return compile_range is not None and ( - compile_range[0] == compile_range[1] and compile_range[1] % tp_size == 0 + compile_range.is_single_size() and compile_range.end % tp_size == 0 ) @VllmInductorPass.time_and_log @@ -1188,10 +1189,10 @@ def register_patterns(self): self.disabled = False - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: if compile_range is None: return False - return compile_range[1] - 1 <= self.max_token_num + return compile_range.end - 1 <= self.max_token_num @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 6124a5428f6c..b95067aba191 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -63,7 +64,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ @@ -99,7 +100,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: """ Load the compiled function from the handle. 
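
A companion sketch for the FlashInfer all-reduce fusion gate above: since `Range.end` is exclusive, `end - 1` is the largest token count a range can see, and the pass only claims ranges that fit its pre-allocated workspace (illustrative, with a made-up `max_token_num`):

    from vllm.config.utils import Range

    def fi_fusion_applicable(compile_range: Range | None, max_token_num: int) -> bool:
        return compile_range is not None and compile_range.end - 1 <= max_token_num

    assert fi_fusion_applicable(Range(start=1, end=2049), max_token_num=2048)
    assert not fi_fusion_applicable(Range(start=1, end=8193), max_token_num=2048)

The compiler-interface hunks just below apply the same single-size test to choose between dynamic_shapes="from_example_inputs" (specialize on that batch size, optionally with max-autotune) and "from_graph" (keep the batch dimension symbolic across the whole range).
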
@@ -213,7 +214,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -223,8 +224,8 @@ def compile( set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(compile_range, tuple): - if compile_range[0] == compile_range[1]: + if compile_range is not None: + if compile_range.is_single_size(): dynamic_shapes = "from_example_inputs" else: dynamic_shapes = "from_graph" @@ -254,7 +255,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -318,7 +319,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -515,7 +516,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -612,7 +613,7 @@ def metrics_context(self) -> contextlib.AbstractContextManager: def set_inductor_config(config, compile_range): - if isinstance(compile_range, tuple) and compile_range[0] == compile_range[1]: + if compile_range is not None and compile_range.is_single_size(): # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE @@ -633,7 +634,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: tuple[int, int] | None = None, + compile_range: Range | None = None, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 599fa776b6c0..008eba4629a3 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,6 +14,7 @@ from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily +from vllm.config.utils import Range from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): @@ -28,8 +29,8 @@ class PassContext: - def __init__(self, compile_range: tuple[int, int] | None): - self.compile_range = compile_range + def __init__(self, compile_range: Range | None): + self.compile_range: Range | None = compile_range def get_pass_context() -> PassContext: @@ -39,7 +40,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(compile_range: tuple[int, int] | None): +def pass_context(compile_range: Range | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. 
""" @@ -96,7 +97,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: tuple[int, int] | None): + def is_applicable_for_range(self, compile_range: Range | None): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 5984f968da35..4664d0d9aefd 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -127,6 +127,8 @@ def uuid(self): for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) - state["compile_range"] = get_pass_context().compile_range + compile_range = get_pass_context().compile_range + if compile_range is not None: + state["compile_range"] = str(compile_range) return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index fe35aaa9e4ae..10844b69c455 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -10,6 +10,7 @@ from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import init_logger logger = init_logger(__name__) @@ -17,7 +18,7 @@ @dataclasses.dataclass class RangeEntry: - compile_range: tuple[int, int] + compile_range: Range compiled: bool = False runnable: Callable = None # type: ignore @@ -61,12 +62,6 @@ def __init__( log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" logger.debug_once(log_string) - self.is_in_range = ( - lambda x, range: range[0] <= x < range[1] - if range[0] < range[1] - else x == range[0] - ) - self.first_run_finished = False self.sym_shape_indices = sym_shape_indices @@ -75,15 +70,15 @@ def __init__( # self.concrete_size_entries: dict[int, RangeEntry] = {} # the entries for ranges that we need to either - self.range_entries: dict[tuple[int, int], RangeEntry] = {} + self.range_entries: dict[Range, RangeEntry] = {} # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_ranges: set[tuple[int, int]] = set(self.compile_ranges) + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. 
for size in self.compile_sizes: - range = (size, size) + range = Range(start=size, end=size) self.range_entries[range] = RangeEntry( compile_range=range, ) @@ -122,7 +117,6 @@ def fakify_args(self, args: list[Any]) -> list[Any]: return fake_example_inputs def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: - is_compile_size = lambda range: range[0] == range[1] if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -131,7 +125,7 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # fakify for range, real args for concrete size args = ( self.fakify_args(args) - if not is_compile_size(range_entry.compile_range) + if not range_entry.compile_range.is_single_size() else args ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( @@ -158,13 +152,18 @@ def __call__(self, *args) -> Any: return range_entry.runnable(*args) runtime_shape = args[self.sym_shape_indices[0]] + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. range_found = False if runtime_shape in self.compile_sizes: - range_entry = self.range_entries[(runtime_shape, runtime_shape)] + range_entry = self.range_entries[ + Range(start=runtime_shape, end=runtime_shape) + ] range_found = True else: for range in self.compile_ranges: - if self.is_in_range(runtime_shape, range): + if range.contains(runtime_shape): range_entry = self.range_entries[range] range_found = True break diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index cf47adb4670a..6a5ee5a0efb7 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -482,7 +483,7 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool: + def is_applicable_for_range(self, compile_range: Range | None) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. 
However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -504,8 +505,8 @@ def is_applicable_for_range(self, compile_range: tuple[int, int] | None) -> bool tp_size = get_tensor_model_parallel_world_size() return ( compile_range is not None - and (compile_range[0] == compile_range[1]) - and (compile_range[1] % tp_size == 0) + and (compile_range.is_single_size()) + and (compile_range.end % tp_size == 0) ) @VllmInductorPass.time_and_log diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 2ae93c59ddfb..298fe4242a83 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -14,7 +14,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config +from vllm.config.utils import Range, config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -945,7 +945,7 @@ def custom_op_log_check(self): op, ) - def get_compile_ranges(self) -> list[tuple[int, int]]: + def get_compile_ranges(self) -> list[Range]: """Get the compile ranges for the compilation config.""" if self.compile_ranges_split_points is None: return [] @@ -953,7 +953,7 @@ def get_compile_ranges(self) -> list[tuple[int, int]]: compile_ranges = [] for i, s in enumerate(split_points): if i == 0: - compile_ranges.append((1, s)) + compile_ranges.append(Range(start=1, end=s)) else: - compile_ranges.append((split_points[i - 1], s)) + compile_ranges.append(Range(start=split_points[i - 1], end=s)) return compile_ranges diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 7e0878d96bbd..7270caf02740 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -6,7 +6,7 @@ import inspect import textwrap from collections.abc import Iterable -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -176,3 +176,37 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: ) processed_overrides[field_name] = value return replace(config, **processed_overrides) + + +@dataclass +class Range: + """ + A range of numbers. + Inclusive of start, exclusive of end. 
+ """ + + start: int + end: int + + def is_single_size(self) -> bool: + return self.start == self.end + + def contains(self, size: int) -> bool: + # Inclusive of start, exclusive of end + if self.is_single_size(): + return size == self.start + return self.start <= size < self.end + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Range): + return False + return self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.start, self.end)) + + def __str__(self) -> str: + return f"(start={self.start}, end={self.end})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 43a3b51b3a0a..a217b3c48f81 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -889,7 +889,11 @@ def _set_compile_ranges(self): # We add 1 because the bounds checks in the compiler are # exclusive and we want to include the max_token_num in the # compile range - computed_compile_ranges_split_points.append(max_token_num + 1) + if ( + max_num_batched_tokens is not None + and max_token_num < max_num_batched_tokens + ): + computed_compile_ranges_split_points.append(max_token_num + 1) if compilation_config.compile_ranges_split_points is not None: for x in compilation_config.compile_ranges_split_points: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f13ff4e726bd..42f9bdeab97e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -14,7 +14,7 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -398,12 +398,27 @@ def compile_or_warm_up_model(self) -> None: # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: + + if ( + not self.model_config.enforce_eager + or self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): warmup_sizes = [ x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the start of the range to ensure compilation/warmup. 
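
# A concrete illustration of the coverage check below, reusing the Range
# helper added in vllm/config/utils.py (start inclusive, end exclusive):
# with split points [8, 32] and max_num_batched_tokens = 8192 the compile
# ranges are (1, 8), (8, 32) and (32, 8193), so a lone capture size of 16
# covers only the middle range and 1 and 32 get appended as warmup sizes.
#
#     from vllm.config.utils import Range
#     compile_ranges = [Range(1, 8), Range(8, 32), Range(32, 8193)]
#     covered = {16}
#     extra = [r.start for r in compile_ranges
#              if not any(r.contains(s) for s in covered)]
#     assert extra == [1, 32]
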
+ all_sizes = set(self.vllm_config.compilation_config.cudagraph_capture_sizes) + all_sizes.update(warmup_sizes) + for compile_range in compile_ranges: + if not any(compile_range.contains(x) for x in all_sizes): + warmup_sizes.append(compile_range.start) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) From f080a83511511a9c0a222451a752a1623aec095d Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 10 Nov 2025 17:20:53 +0100 Subject: [PATCH 019/183] [RFC][ROCm][AITER] Keep all AITER kernels in `_aiter_ops` class like `_custom_ops` and `_ipex_ops` (#24490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: vllmellm Co-authored-by: Luka Govedič --- docs/design/moe_kernel_features.md | 2 +- tests/kernels/moe/test_moe.py | 11 +- .../model_executor/test_enabled_custom_ops.py | 41 +- vllm/_aiter_ops.py | 941 ++++++++++++++++++ vllm/attention/ops/rocm_aiter_mla.py | 105 -- vllm/envs.py | 8 +- .../layers/fused_moe/fused_moe.py | 15 +- vllm/model_executor/layers/fused_moe/layer.py | 83 +- .../layers/fused_moe/rocm_aiter_fused_moe.py | 329 +----- vllm/model_executor/layers/layernorm.py | 90 +- .../compressed_tensors_moe.py | 12 +- .../schemes/compressed_tensors_w8a8_fp8.py | 4 +- .../model_executor/layers/quantization/fp8.py | 16 +- .../quantization/kernels/scaled_mm/aiter.py | 48 +- .../layers/quantization/quark/quark_moe.py | 47 +- .../quark/schemes/quark_ocp_mx.py | 7 + .../layers/quantization/utils/fp8_utils.py | 124 +-- .../layers/quantization/utils/w8a8_utils.py | 2 +- .../layers/rotary_embedding/base.py | 13 +- .../rotary_embedding/deepseek_scaling_rope.py | 9 + .../rotary_embedding/rocm_aiter_rope_ops.py | 94 -- vllm/model_executor/models/deepseek_v2.py | 27 +- vllm/platforms/rocm.py | 27 +- vllm/v1/attention/backends/mla/common.py | 55 +- .../attention/backends/mla/rocm_aiter_mla.py | 9 +- 25 files changed, 1194 insertions(+), 925 deletions(-) create mode 100644 vllm/_aiter_ops.py delete mode 100644 vllm/attention/ops/rocm_aiter_mla.py delete mode 100644 vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 633e23eea33e..ee224e6922fb 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | -| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] | +| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 014df1fa111f..c27cf2468ede 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,6 +6,8 @@ """ import functools +import importlib +import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -20,6 +22,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context @@ -412,14 +415,12 @@ def test_mixtral_moe( huggingface.""" # clear the cache before every test - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) + # Force reload aiter_ops to pick up the new environment variables. + if "rocm_aiter_ops" in sys.modules: + importlib.reload(rocm_aiter_ops) - is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 41419553aa83..9121284de85b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -4,6 +4,7 @@ import pytest import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import ( @@ -15,9 +16,6 @@ dispatch_topk_func, vllm_topk_softmax, ) -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, -) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_rocm_rmsnorm_func, @@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - topk_func = dispatch_topk_func() - is_rocm_aiter_moe_enabled.cache_clear() - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax, - ) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_topk_dispatch(use_rocm_aiter: bool): + topk_func = dispatch_topk_func(use_rocm_aiter) - assert topk_func == rocm_aiter_topk_softmax + if current_platform.is_rocm() and use_rocm_aiter: + assert topk_func == rocm_aiter_ops.topk_softmax else: assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.skipif( not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" ) def test_rms_norm_dispatch( - 
add_residual: bool, - dtype: torch.dtype, - use_rocm_aiter: str, - use_rocm_aiter_norm: str, - monkeypatch, + add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool ): - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) - rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) + rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) should_use_rocm_aiter = ( current_platform.is_rocm() - and int(use_rocm_aiter) - and int(use_rocm_aiter_norm) + and use_rocm_aiter and dtype in RMS_NORM_SUPPORTED_DTYPES ) if add_residual and should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add elif should_use_rocm_aiter: - assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm + assert rms_norm_func == rocm_aiter_ops.rms_norm elif add_residual: assert rms_norm_func == fused_add_rms_norm else: diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py new file mode 100644 index 000000000000..9a4b5f3399be --- /dev/null +++ b/vllm/_aiter_ops.py @@ -0,0 +1,941 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer + + +def is_aiter_found() -> bool: + from importlib.util import find_spec + + return find_spec("aiter") is not None + + +# `find_spec` is not torch.compile compatible. +# In cases where aiter availability might have +# been checked in forward passes that are torch compiled. +# we keep this global outside to not cause torch compile breaks. +IS_AITER_FOUND = is_aiter_found() + + +def if_aiter_supported(func: Callable) -> Callable: + """Decorator that only executes the function if + ROCm AITER package is supported on gfx9 archs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # checks the platform, device arch and aiter library existance. 
+ + from vllm.platforms.rocm import on_gfx9 + + if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND: + return func(*args, **kwargs) + else: + # Return None or do nothing if not supported + return None + + return wrapper + + +def _rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + +def _rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_asm_moe_tkw1_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = ActivationType(activation_method) + + return asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation, + ) + + +def _rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _rocm_aiter_topk_softmax_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + from aiter import topk_softmax + + topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + + +def _rocm_aiter_topk_softmax_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + 
token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: + pass + + +def _rocm_aiter_biased_grouped_topk_impl( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + from aiter import biased_grouped_topk + + biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + +def _rocm_aiter_biased_grouped_topk_fake( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_grouped_topk_impl( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + is_softmax = scoring_func == "softmax" + from aiter import grouped_topk + + grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + is_softmax, + routed_scaling_factor, + ) + + +def _rocm_aiter_grouped_topk_fake( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights +) -> None: + pass + + +def _rocm_aiter_mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + +def _rocm_aiter_mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +def _rocm_aiter_gemm_w8a8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) + + +def _rocm_aiter_gemm_w8a8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_gemm_w8a8_blockscale_impl( + A: 
torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + from aiter import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + +def _rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +def _rocm_aiter_rms_norm_impl( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + from aiter import rms_norm + + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rms_norm(x, weight, variance_epsilon) + + +def _rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: + return torch.empty_like(x) + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter import rmsnorm2d_fwd_with_add + + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) + rmsnorm2d_fwd_with_add( + output, # output + x, # input + residual, # residual input + residual_out, # residual output + weight, + variance_epsilon, + ) + return output, residual_out + + +def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +# Global flag to ensure ops are registered only once +_OPS_REGISTERED = False + + +class rocm_aiter_ops: + _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + + @classmethod + @if_aiter_supported + def is_enabled(cls) -> bool: + """Verifies device specs and availability of aiter main env variable.""" + return cls._AITER_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._LINEAR_ENABLED + + @classmethod + @if_aiter_supported + def is_linear_fp8_enaled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + + @classmethod + @if_aiter_supported + def is_rmsnorm_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._RMSNORM_ENABLED + + @classmethod + @if_aiter_supported + def is_fused_moe_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and 
cls._FMOE_ENABLED + + @classmethod + @if_aiter_supported + def is_fusion_moe_shared_experts_enabled(cls) -> bool: + return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED + + @classmethod + @if_aiter_supported + def is_mla_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MLA_ENABLED + + @classmethod + @if_aiter_supported + def is_mha_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._MHA_ENABLED + + @classmethod + @if_aiter_supported + def is_pa_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_triton_unified_attn_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED + + @classmethod + @if_aiter_supported + def is_fp8bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + + @classmethod + @if_aiter_supported + def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM + + @classmethod + @if_aiter_supported + def is_triton_rotary_embed_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED + + @staticmethod + @if_aiter_supported + def register_ops_once() -> None: + global _OPS_REGISTERED + if not _OPS_REGISTERED: + tags = ( + tuple() + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ) + + # register all the custom ops here + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=_rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=_rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fused_moe", + op_func=_rocm_aiter_fused_moe_impl, + mutates_args=[], + fake_impl=_rocm_aiter_fused_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=_rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=_rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_biased_grouped_topk", + op_func=_rocm_aiter_biased_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_biased_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_grouped_topk", + op_func=_rocm_aiter_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=_rocm_aiter_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_mla_decode_fwd", + op_func=_rocm_aiter_mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=_rocm_aiter_mla_decode_fwd_fake, + tags=tags, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8", + op_func=_rocm_aiter_gemm_w8a8_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=_rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake, + 
dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=_rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, + mutates_args=[], + fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + + _OPS_REGISTERED = True + + @staticmethod + def rms_norm2d_with_add( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add( + x, residual, weight, variance_epsilon + ) + + @staticmethod + def rms_norm( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon) + + @staticmethod + def gemm_w8a8( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype) + + @staticmethod + def gemm_w8a8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + A, B, As, Bs, output_dtype + ) + + @staticmethod + def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + quant_method: int = 0, + doweight_stage1: bool = False, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation_method, + quant_method, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) + + @staticmethod + def asm_moe_tkw1( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, + a16: bool = False, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, + activation_method: int = 0, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale, + fc2_scale, + fc1_smooth_scale, + fc2_smooth_scale, + a16, + per_tensor_quant_scale, + expert_mask, + activation_method, + ) + + @staticmethod + def topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + ) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) + return topk_weights, topk_indices + + @staticmethod + def biased_grouped_topk( + gating_output: torch.Tensor, + correction_bias: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) + + @staticmethod + def grouped_topk( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + ) -> None: + torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + scoring_func, + routed_scaling_factor, + ) + + @staticmethod + def mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, + logit_cap: float = 0.0, + ): + torch.ops.vllm.rocm_aiter_mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) + + @staticmethod + def triton_fp4_gemm_dynamic_qaunt( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype | None = torch.bfloat16, + x_scales: torch.Tensor | None = None, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + if x_scales is None: + x_q, x_s = dynamic_mxfp4_quant(x) + else: + x_q = x + x_s = x_scales + + y = torch.empty( + x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) + return y + + @staticmethod + def triton_rotary_embed( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, + is_neox_style: bool, + ): + from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace + + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox_style else 1 + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + rope_cached_thd_positions_2c_fwd_inplace( + positions, + sin, + cos, + query_, + key_, + rotate_style, + reuse_freqs_front_part=True, + is_nope_first=False, + ) + query = query.view(query_shape) + key = key.view(key_shape) + + @staticmethod + def triton_fp8_bmm( + X: torch.Tensor, + WQ: torch.Tensor, + w_scale: torch.Tensor, + group_size: int = 128, + bias: torch.Tensor | None = None, + dtype: torch.dtype | None = torch.bfloat16, + splitK: int | None = None, + YQ: torch.Tensor | None = None, + transpose_bm: bool | None = False, + config: dict | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, + ) + + return aiter_triton_fp8_bmm( + X, + WQ, 
+ w_scale, + group_size=group_size, + bias=bias, + dtype=dtype, + splitK=splitK, + YQ=YQ, + transpose_bm=transpose_bm, + config=config, + ) + + @staticmethod + def triton_gemm_a8w8_blockscale( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, + ) -> torch.Tensor: + from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale + + return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) + + @staticmethod + def per_1x128_fp8_quant( + input_2d: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + """Only applies quantization method for fp8 data type only.""" + from aiter import QuantType, dtypes, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8) + + @staticmethod + def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: + return (n, k) in [ + (1024, 8192), + (2112, 7168), + (3072, 1536), + (32768, 8192), + (4096, 7168), + (4608, 7168), + (512, 7168), + (7168, 2048), + (7168, 256), + (8192, 1024), + (8192, 32768), + ] + + @staticmethod + def shuffle_weight( + self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> torch.Tensor: + from aiter.ops.shuffle import shuffle_weight + + return shuffle_weight(tensor, layout=layout) + + @staticmethod + def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) + ) -> tuple[torch.Tensor, ...]: + """ + Applies shuffle_weight function from AITER to each + input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. + + Args: + *tensors: Variable number of torch.Tensor objects. + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). + + Returns: + A Tuple of shuffled tensors. 
+ """ + from aiter.ops.shuffle import shuffle_weight + + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) + + +rocm_aiter_ops.register_ops_once() diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py deleted file mode 100644 index 6308f63cc4e7..000000000000 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -import torch - -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer - - -def get_aiter_mla_metadata( - max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device -) -> tuple[torch.Tensor, ...]: - paged_kv_indices = torch.zeros( - max_batch_size * max_block_per_batch, dtype=torch.int32, device=device - ) - paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device) - paged_kv_last_page_lens = torch.full( - (max_batch_size,), block_size, dtype=torch.int32 - ) - qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) - return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr - - -def aiter_mla_decode_fwd( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - sm_scale: float, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - logit_cap: float = 0.0, -): - torch.ops.vllm.rocm_aiter_mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - max_seqlen_qo, - kv_indptr, - kv_indices, - kv_last_page_lens, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_impl( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - from aiter.mla import mla_decode_fwd - - mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap, - ) - - -def mla_decode_fwd_fake( - q: torch.Tensor, - kv_buffer: torch.Tensor, - o: torch.Tensor, - qo_indptr: torch.Tensor, - max_seqlen_qo: int, - kv_indptr: torch.Tensor | None = None, - kv_indices: torch.Tensor | None = None, - kv_last_page_lens: torch.Tensor | None = None, - sm_scale: float = 1.0, - logit_cap: float = 0.0, -) -> None: - pass - - -if current_platform.is_rocm(): - if is_torch_equal_or_newer("2.7.0"): - tags = () - else: - tags = ((torch.Tag.needs_fixed_stride_order,),) - direct_register_custom_op( - op_name="rocm_aiter_mla_decode_fwd", - op_func=mla_decode_fwd_impl, - mutates_args=["o"], - fake_impl=mla_decode_fwd_fake, - tags=tags, - ) diff --git a/vllm/envs.py b/vllm/envs.py index 078e5c38f0f4..30c62e90e9fb 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -109,7 +109,7 @@ VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_TRITON_ROPE: bool = False + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True @@ -926,8 +926,8 
@@ def get_vllm_port() -> int | None: ), # Whether to use aiter rope. # By default is disabled. - "VLLM_ROCM_USE_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") + "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. @@ -1589,7 +1589,7 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_MLA", "VLLM_ROCM_USE_AITER_MHA", "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", - "VLLM_ROCM_USE_TRITON_ROPE", + "VLLM_ROCM_USE_AITER_TRITON_ROPE", "VLLM_ROCM_USE_AITER_FP8BMM", "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_AITER_TRITON_GEMM", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7ad3ce1397b3..2e042d85fcfc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -14,6 +14,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -55,8 +56,6 @@ from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer -from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled - logger = init_logger(__name__) @@ -1089,11 +1088,11 @@ def vllm_topk_softmax( return topk_weights, topk_indices -def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax - - return rocm_aiter_topk_softmax +def dispatch_topk_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_softmax return vllm_topk_softmax @@ -1121,7 +1120,7 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - topk_func = dispatch_topk_func() + topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) topk_weights, topk_ids = topk_func( topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e69ead074c50..45b0f50a7997 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,6 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( @@ -41,8 +42,6 @@ ) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, ) from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( @@ -92,13 +91,11 @@ def _eplb_map_to_physical_and_record( return topk_ids eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record +from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + 
rocm_aiter_grouped_topk, +) -if is_rocm_aiter_moe_enabled(): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk_aiter, - ) -else: - from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas else: @@ -463,7 +460,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -620,13 +618,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Padding the weight for better performance on ROCm layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - # Lazy import to avoid importing triton. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights, - ) if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -1002,6 +996,7 @@ def determine_expert_map( global_num_experts: int, expert_placement_strategy: ExpertPlacementStrategy = "linear", num_fused_shared_experts: int = 0, + return_expert_mask: bool = False, ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: """ Calculates how many experts should be assigned to each rank for EP and @@ -1064,7 +1059,7 @@ def determine_expert_map( ) expert_mask = None - if is_rocm_aiter_moe_enabled(): + if return_expert_mask: expert_mask = torch.ones( (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 ) @@ -1292,14 +1287,18 @@ def __init__( self.logical_replica_count: torch.Tensor | None = None # ROCm aiter shared experts fusion + self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.aiter_fmoe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + self.num_fused_shared_experts = ( n_shared_experts - if n_shared_experts is not None - and is_rocm_aiter_fusion_shared_expert_enabled() + if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled else 0 ) if ( - not is_rocm_aiter_fusion_shared_expert_enabled() + not self.aiter_fmoe_shared_expert_enabled and self.num_fused_shared_experts != 0 ): raise ValueError( @@ -1346,6 +1345,7 @@ def __init__( global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) @@ -1570,13 +1570,16 @@ def update_expert_map(self): ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) self.register_buffer("expert_mask", expert_mask) - self._init_aiter_shared_experts_topK_buffer( - vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size - ) + if self.aiter_fmoe_shared_expert_enabled: + self._init_aiter_shared_experts_topK_buffer( + 
vllm_config=get_current_vllm_config(), + dp_size=get_dp_group().world_size, + ) def _load_per_tensor_weight_scale( self, @@ -1753,20 +1756,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _init_aiter_shared_experts_topK_buffer( self, vllm_config: VllmConfig, dp_size: int ): - if is_rocm_aiter_fusion_shared_expert_enabled(): - if self.num_fused_shared_experts > 0: - init_aiter_topK_meta_data( - n_routed_experts=self.global_num_experts, - n_shared_experts=self.num_fused_shared_experts, - top_k=self.top_k, - tp_rank=self.ep_rank if self.use_ep else self.tp_rank, - tp_size=self.ep_size if self.use_ep else self.tp_size, - shared_experts_score=1.0, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens - * dp_size, - is_EP=self.use_ep, - ) - self.local_num_experts += self.num_fused_shared_experts + if self.num_fused_shared_experts > 0: + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens + * dp_size, + is_EP=self.use_ep, + ) + self.local_num_experts += self.num_fused_shared_experts @overload def weight_loader( @@ -2208,15 +2210,16 @@ def select_experts( elif use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - if is_rocm_aiter_moe_enabled(): - if not is_rocm_aiter_fusion_shared_expert_enabled(): + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): assert num_fused_shared_experts == 0 grouped_topk_impl = partial( - grouped_topk_aiter, + rocm_aiter_grouped_topk, num_fused_shared_experts=num_fused_shared_experts, ) else: grouped_topk_impl = grouped_topk + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, @@ -2448,7 +2451,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, @@ -2612,7 +2615,7 @@ def forward_impl( use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, expert_map=self.expert_map - if not is_rocm_aiter_moe_enabled() + if not self.rocm_aiter_fmoe_enabled else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index e18514ad43f6..8f05828d74f5 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum -from functools import cache, lru_cache +from functools import lru_cache import torch -from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, ) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op class 
QuantMethod(IntEnum): @@ -37,27 +35,6 @@ class ActivationMethod(IntEnum): GELU = 1 -@cache -def is_rocm_aiter_moe_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_MOE - and envs.VLLM_ROCM_USE_AITER - ) - - -@cache -def use_mxfp4_aiter_moe() -> bool: - return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER - - -@cache -def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: - return ( - envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled() - ) - - aiter_topK_meta_data = None @@ -114,250 +91,6 @@ def init_aiter_topK_meta_data( aiter_topK_meta_data = (total_topk_weights, total_topk_ids) -def rocm_aiter_asm_moe_tkw1_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - from aiter import ActivationType - from aiter.fused_moe_bf16_asm import asm_moe_tkw1 - - activation = ActivationType(activation_method) - - return asm_moe_tkw1( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - per_tensor_quant_scale=per_tensor_quant_scale, - expert_mask=expert_mask, - activation=activation, - ) - - -def rocm_aiter_asm_moe_tkw1_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: torch.Tensor | None = None, - fc2_scale: torch.Tensor | None = None, - fc1_smooth_scale: torch.Tensor | None = None, - fc2_smooth_scale: torch.Tensor | None = None, - a16: bool = False, - per_tensor_quant_scale: torch.Tensor | None = None, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def rocm_aiter_topk_softmax_impl( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - from aiter import topk_softmax - - topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - - -def rocm_aiter_topk_softmax_fake( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> None: - pass - - -def rocm_aiter_biased_grouped_topk_impl( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import biased_grouped_topk - - biased_grouped_topk( - gating_output, - correction_bias, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - routed_scaling_factor, - ) - - -def rocm_aiter_biased_grouped_topk_fake( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 
1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_grouped_topk_impl( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - from aiter import grouped_topk - - grouped_topk( - gating_output, - topk_weights, - topk_ids, - num_expert_group, - topk_group, - need_renorm, - scoring_func, - routed_scaling_factor, - ) - - -def rocm_aiter_grouped_topk_fake( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, # mul to topk_weights -) -> None: - pass - - -def rocm_aiter_fused_moe_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - activation = ActivationType(activation_method) - quant_type = QuantType(quant_method) - - return fused_moe( - hidden_states, - w1, - w2, - topk_weight, - topk_ids, - expert_mask, - activation, - quant_type, - doweight_stage1, - w1_scale, - w2_scale, - a1_scale, - a2_scale, - ) - - -def rocm_aiter_fused_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weight: torch.Tensor, - topk_ids: torch.Tensor, - expert_mask: torch.Tensor | None = None, - activation_method: int = ActivationMethod.SILU.value, - quant_method: int = QuantMethod.NO.value, - doweight_stage1: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - a1_scale: torch.Tensor | None = None, - a2_scale: torch.Tensor | None = None, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_asm_moe_tkw1", - op_func=rocm_aiter_asm_moe_tkw1_impl, - fake_impl=rocm_aiter_asm_moe_tkw1_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_fused_moe", - op_func=rocm_aiter_fused_moe_impl, - fake_impl=rocm_aiter_fused_moe_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_topk_softmax", - op_func=rocm_aiter_topk_softmax_impl, - mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], - fake_impl=rocm_aiter_topk_softmax_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_biased_grouped_topk", - op_func=rocm_aiter_biased_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_biased_grouped_topk_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_grouped_topk", - op_func=rocm_aiter_grouped_topk_impl, - mutates_args=["topk_weights", "topk_ids"], - fake_impl=rocm_aiter_grouped_topk_fake, - ) - - def rocm_aiter_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk( ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device - if is_rocm_aiter_fusion_shared_expert_enabled() and 
num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): assert aiter_topK_meta_data is not None, ( "AITER topK meta data is not initialized. " "Please ensure that init_aiter_topK_meta_data " @@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk( topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: - torch.ops.vllm.rocm_aiter_biased_grouped_topk( + rocm_aiter_ops.biased_grouped_topk( gating_output, e_score_correction_bias.to(gating_output.dtype), topk_weights, @@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk( ) else: assert scoring_func == "softmax" or scoring_func == "sigmoid" - torch.ops.vllm.rocm_aiter_grouped_topk( + rocm_aiter_ops.grouped_topk( gating_output, topk_weights, topk_ids, @@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk( routed_scaling_factor=routed_scaling_factor, ) - if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + if ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + and num_fused_shared_experts > 0 + ): return total_topk_weights, total_topk_ids return topk_weights, topk_ids @@ -464,7 +203,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + return rocm_aiter_ops.asm_moe_tkw1( hidden_states, w1, w2, @@ -482,7 +221,9 @@ def rocm_aiter_fused_experts( else: quant_method = QuantMethod.NO.value - + # quark moe for mxfp4 w_dtype + if quant_config.use_mxfp4_w4a16: + quant_method = QuantMethod.BLOCK_1X32.value # w8a8 block-scaled if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( @@ -507,7 +248,7 @@ def rocm_aiter_fused_experts( "Only support topk=1 when `apply_router_weight_on_input` is True" ) - return torch.ops.vllm.rocm_aiter_fused_moe( + return rocm_aiter_ops.fused_moe( hidden_states, w1, w2, @@ -522,39 +263,3 @@ def rocm_aiter_fused_experts( a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input, ) - - -def rocm_aiter_topk_softmax( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> tuple[torch.Tensor, ...]: - torch.ops.vllm.rocm_aiter_topk_softmax( - topk_weights, topk_indices, token_expert_indices, gating_output, renormalize - ) - return topk_weights, topk_indices - - -def shuffle_weights( - *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) -) -> tuple[torch.Tensor, ...]: - """ - Applies shuffle_weight function from AITER to each - input tensor and returns them. - - Rearranges (shuffles) the input tensor/s - into a specified block layout for optimized computation. - - Args: - *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the block sizes used to divide - the tensors during shuffling. Default is (16, 16). - - Returns: - A Tuple of shuffled tensors. 
- """ - from aiter.ops.shuffle import shuffle_weight - - return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a883ac81f41e..8cc374ac9155 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -6,18 +6,13 @@ import torch.nn as nn import torch.nn.functional as F -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.batch_invariant import ( rms_norm_batch_invariant, vllm_is_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_aiter_rmsnorm_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER def rms_norm( @@ -58,80 +53,34 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm_impl( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +def poly_norm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: - import aiter as rocm_aiter - - if x.dim() > 2: - x_original_shape = x.shape - x = x.reshape(-1, x_original_shape[-1]) - x = rocm_aiter.rms_norm(x, weight, variance_epsilon) - return x.reshape(x_original_shape) - - return rocm_aiter.rms_norm(x, weight, variance_epsilon) - + from vllm import _custom_ops as ops -def rocm_aiter_rmsnorm2d_fwd_with_add_impl( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - import aiter as rocm_aiter - - residual_out = torch.empty_like(residual) - output = torch.empty_like(x) - rocm_aiter.rmsnorm2d_fwd_with_add( - output, # output - x, # input - residual, # residual input - residual_out, # residual output + out = torch.empty_like(x) + ops.poly_norm( + out, + x, weight, + bias, variance_epsilon, ) - return output, residual_out - - -def rocm_aiter_rms_norm_fake( - x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float -) -> torch.Tensor: - return torch.empty_like(x) - - -def rocm_aiter_rmsnorm2d_fwd_with_add_fake( - x: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - variance_epsilon: float, -) -> tuple[torch.Tensor, torch.Tensor]: - return torch.empty_like(x), torch.empty_like(residual) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_rms_norm", - op_func=rocm_aiter_rms_norm_impl, - fake_impl=rocm_aiter_rms_norm_fake, - ) - - direct_register_custom_op( - op_name="rocm_aiter_rmsnorm2d_fwd_with_add", - op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, - fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, - ) + return out -def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): - use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ +def dispatch_rocm_rmsnorm_func( + with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False +): + use_aiter = use_aiter and dtype in [ torch.float16, torch.bfloat16, ] if use_aiter and with_fused_add: - return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + return rocm_aiter_ops.rms_norm2d_with_add if use_aiter: - return torch.ops.vllm.rocm_aiter_rms_norm + return rocm_aiter_ops.rms_norm # fall back to CUDA implementation if with_fused_add: @@ -169,11 +118,14 @@ def __init__( self.weight = nn.Parameter(self.weight) if current_platform.is_rocm(): + aiter_rmsnorm_enabled = 
rocm_aiter_ops.is_rmsnorm_enabled() self.rocm_norm_func = dispatch_rocm_rmsnorm_func( - with_fused_add=False, dtype=weight_dtype + with_fused_add=False, + dtype=weight_dtype, + use_aiter=aiter_rmsnorm_enabled, ) self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( - with_fused_add=True, dtype=weight_dtype + with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled ) @staticmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d95d49eddfe3..d32ae6674ee6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -582,11 +583,8 @@ def __init__( # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( @@ -829,12 +827,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ee431c9148b8..6da136cbc8f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -7,12 +7,12 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) ) self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() if self.weight_block_size is not None: assert not self.is_static_input_scheme diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ce40645782e5..e4e1cbff712f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -12,6 +12,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( @@ -56,7 +57,6 @@ ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, - check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, @@ -369,7 +369,7 @@ def __init__(self, quant_config: Fp8Config): if vllm_is_batch_invariant(): self.use_marlin = False - self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled() self.use_deep_gemm = is_deep_gemm_supported() self.weight_block_size = self.quant_config.weight_block_size @@ -869,12 +869,8 @@ def create_weights( def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - shuffle_weights, - ) - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -916,7 +912,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
- shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -962,7 +958,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) @@ -1042,7 +1038,7 @@ def process_weights_after_loading(self, layer: Module) -> None: start += shard_size if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index a19396a162bc..f5cd91469b78 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -4,54 +4,14 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig -def rocm_aiter_gemm_w8a8_impl( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - from aiter import gemm_a8w8_CK - - # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects - # a to be [M, K] - # b to be [N, K] - # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) - - -def rocm_aiter_gemm_w8a8_fake( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - bias: torch.Tensor | None = None, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = A.shape[0] - n = B.shape[0] - Y = torch.empty(m, n, dtype=output_dtype, device=A.device) - return Y - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8", - op_func=rocm_aiter_gemm_w8a8_impl, - fake_impl=rocm_aiter_gemm_w8a8_fake, - ) - - class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod def get_min_capability(cls) -> int: @@ -75,7 +35,7 @@ def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + "installed on ROCm.", ) # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + if not (rocm_aiter_ops.is_linear_enabled()): return ( False, "AiterScaledMMLinearKernel is disabled. 
" @@ -157,6 +117,4 @@ def apply_weights( # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return torch.ops.vllm.rocm_aiter_gemm_w8a8( - x_q, w_q.t(), x_s, w_s, bias, out_dtype - ) + return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index eca6b0cb1d8e..30772c3665b0 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, @@ -21,10 +22,6 @@ ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, - use_mxfp4_aiter_moe, -) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) @@ -122,7 +119,7 @@ def __init__( if current_platform.is_rocm(): self.use_marlin = False - self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() def create_weights( self, @@ -309,12 +306,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - shuffle_weights, - ) - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = shuffle_weights( + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) @@ -470,13 +463,15 @@ def __init__( "not implemented. Please open an issue." ) + self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() + self.emulate = not current_platform.supports_mx() or not ( - use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" ) if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, " - f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, " + f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, " f"ocp_mx_scheme={self.ocp_mx_scheme}) " "does not support native MXFP4/MXFP6 " "computation. 
Simulated weight dequantization and activation " @@ -656,28 +651,18 @@ def apply( ) if not self.emulate: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - - aiter_acts = { - ActivationType.No.name.lower(): ActivationType.No, - ActivationType.Silu.name.lower(): ActivationType.Silu, - ActivationType.Gelu.name.lower(): ActivationType.Gelu, - } - assert activation in aiter_acts, ( - f"Aiter CK fp4 MoE doesn't support activation {activation}" - ) - out = fused_moe( + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) + + out = rocm_aiter_fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights, - topk_ids, - quant_type=QuantType.per_1x32, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - activation=aiter_acts[activation], - doweight_stage1=False, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + quant_config=self.moe_quant_config, ) else: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c25c522dea55..007e78e68d5c 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -31,6 +31,13 @@ logger = init_logger(__name__) +# TODO: move registration of custom op to aiter_ops.py +# `from vllm._aiter_ops import rocm_aiter_ops` +# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()` +# for envs checks which does not require @cache anymore. +# triton kernel is torch compile compatible. +# does not require direct registeration. +# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`. 
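[Editor's note] The TODO block above describes a planned refactor rather than code in this patch: move the FP4 GEMM helpers onto the `rocm_aiter_ops` facade, replace the `@cache`-decorated env check with a plain predicate, and call the Triton kernel directly since it is torch.compile compatible and needs no `direct_register_custom_op`. A minimal sketch of that shape follows, for orientation only. The class and method names are taken from the TODO; the environment-variable handling and the placeholder kernel are assumptions for illustration and are not part of this patch series or of the real aiter API.

```python
# Illustrative sketch only -- not part of this patch series.
import os

import torch


def _placeholder_fp4_gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the aiter Triton FP4 GEMM with dynamic quantization.
    # A Triton kernel called this way is torch.compile compatible, so it
    # does not need to be registered as a custom op.
    return x @ w.t()


class rocm_aiter_ops:
    """Excerpt of the facade the TODO refers to (names per the TODO)."""

    @staticmethod
    def is_asm_fp4_gemm_dynamic_quant_enabled() -> bool:
        # Plain environment check; cheap enough that @cache is unnecessary.
        # The flag name checked here is an assumption for this sketch.
        return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

    @staticmethod
    def triton_fp4_gemm_dynamic_quant(
        x: torch.Tensor, weight: torch.Tensor
    ) -> torch.Tensor:
        # Thin wrapper so callers depend only on the facade.
        return _placeholder_fp4_gemm(x, weight)
```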
@cache def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: return ( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7fecda2166ef..63726c07b7d1 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -68,78 +69,6 @@ def cutlass_scaled_mm( ) -def rocm_aiter_gemm_w8a8_blockscale_impl( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - def is_aiter_triton_kernel_tuned(n, k): - return (n, k) in [ - (1024, 8192), - (2112, 7168), - (3072, 1536), - (32768, 8192), - (4096, 7168), - (4608, 7168), - (512, 7168), - (7168, 2048), - (7168, 256), - (8192, 1024), - (8192, 32768), - ] - - n, k = weight.shape - if input_scale is not None: - q_input = input_2d - elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k): - from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale - - # MI350 case uses triton kernel - q_input, input_scale = per_token_group_quant_fp8( - input_2d, - group_size, - column_major_scales=False, - use_ue8m0=False, - ) - else: - # MI300 uses tuned AITER ASM/C++ kernel - import aiter as rocm_aiter - from aiter import gemm_a8w8_blockscale, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) - q_input, input_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8 - ) - - return gemm_a8w8_blockscale( - q_input, weight, input_scale, weight_scale, dtype=output_dtype - ) - - -def rocm_aiter_gemm_w8a8_blockscale_fake( - input_2d: torch.Tensor, - weight: torch.Tensor, - input_scale: torch.Tensor, - weight_scale: torch.Tensor, - group_size: int, - output_dtype: torch.dtype = torch.float16, -) -> torch.Tensor: - m = input_2d.shape[0] - n = weight.shape[0] - return torch.empty(m, n, dtype=output_dtype, device=input_2d.device) - - -if current_platform.is_rocm(): - direct_register_custom_op( - op_name="rocm_aiter_gemm_w8a8_blockscale", - op_func=rocm_aiter_gemm_w8a8_blockscale_impl, - fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, - ) - - # TODO we should be able to change the type of block_size to GroupShape # after we resolve GroupShape compilation issue # https://github.com/vllm-project/vllm/issues/25270 @@ -385,14 +314,40 @@ def _run_aiter( input_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert self.act_quant_group_shape == GroupShape(1, 128) - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - input_2d, - weight, - input_scale, - weight_scale, - self.act_quant_group_shape.col, - input_2d.dtype, - ) + + n, k = weight.shape + if input_scale is not None: + q_input = input_2d + + # MI350 case uses triton kernel + if ( + not current_platform.is_fp8_fnuz() + and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) + ): + q_input, input_scale = per_token_group_quant_fp8( + input_2d, + self.act_quant_group_shape.col, + column_major_scales=False, + use_ue8m0=False, + ) + return rocm_aiter_ops.triton_gemm_a8w8_blockscale( + q_input, + weight, + input_scale, + weight_scale, + 
input_2d.dtype, + ) + + # MI300 uses tuned AITER ASM/C++ kernel + else: + q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) + return rocm_aiter_ops.gemm_w8a8_blockscale( + q_input, + weight, + input_scale, + weight_scale, + input_2d.dtype, + ) def _run_triton( self, @@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace( s_old.copy_(s_requant) -def check_aiter_fp8_linear_support() -> bool: - """AITER is only supported on ROCm for MI3XX""" - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - ) - - def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 380431e86435..7fe902807a74 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -472,7 +472,7 @@ def apply( # Example: # When the number of token is 1, per-token scale is [[1]] # When per-tensor scale is [1] or (). - per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 + per_tensor_weights = weight_scale.numel() == 1 per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 # TODO(luka) do this dispatch during init (after ScaledMM refactor) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 91276320df4d..2ef54e75df44 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,13 +4,10 @@ import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch -from .rocm_aiter_rope_ops import ( - is_rocm_triton_rotary_embedding_enabled, - rocm_aiter_rotary_emb, -) @CustomOp.register("rotary_embedding") @@ -48,8 +45,8 @@ def __init__( cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_triton_rotary_embedding_enabled = ( - is_rocm_triton_rotary_embedding_enabled() + self.is_rocm_triton_rotary_embed_enabled = ( + rocm_aiter_ops.is_triton_rotary_embed_enabled() ) def _compute_inv_freq(self, base: float) -> torch.Tensor: @@ -169,9 +166,9 @@ def forward_hip( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if self.is_rocm_triton_rotary_embedding_enabled: + if self.is_rocm_triton_rotary_embed_enabled: self._match_cos_sin_cache_dtype(query) - rocm_aiter_rotary_emb( + rocm_aiter_ops.triton_rotary_embed( positions, query, key, diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index d9134f05fddf..e72834e473c1 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -146,6 +146,15 @@ def forward_native( key = key_rot return query, key + def forward_hip( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.forward_native(positions, query, key, offsets) + def forward_cuda( 
self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py deleted file mode 100644 index a01d14f7b3a1..000000000000 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.envs as envs -from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op - - -def is_rocm_triton_rotary_embedding_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_TRITON_ROPE - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_impl( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - import aiter.ops.triton.rope as ops - - ops.rope_cached_thd_positions_2c_fwd_inplace( - query, - key, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=is_nope_first, - ) - - -def rocm_aiter_rotary_emb_with_key_forward_triton_fake( - positions: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - rotate_style: int = 0, - is_nope_first: bool = False, -) -> None: - pass - - -if is_rocm_triton_rotary_embedding_enabled(): - direct_register_custom_op( - op_name="rocm_aiter_rotary_emb_with_key_forward_triton", - op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, - mutates_args=["key", "query"], - fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake, - dispatch_key=current_platform.dispatch_key, - ) - - -def rocm_aiter_rotary_emb( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - cos_sin_cache: torch.Tensor, - head_size: int, - rotary_dim: int, - is_neox_style: bool, -): - num_tokens = positions.numel() - cos, sin = cos_sin_cache.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - rotate_style = 0 if is_neox_style else 1 - - query = query.view(num_tokens, -1, head_size) - key = key.view(num_tokens, -1, head_size) - query_ = query[..., :rotary_dim] - key_ = key[..., :rotary_dim] - positions = positions.view(*query.shape[:1]) - torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton( - positions, - sin, - cos, - query_, - key_, - rotate_style, - False, - ) - query = query.view(query_shape) - key = key.view(key_shape) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 63eaf63cc3c4..38189e17f7d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -33,6 +33,7 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton @@ -50,10 +51,6 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_fusion_shared_expert_enabled, - is_rocm_aiter_moe_enabled, -) from 
vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -294,10 +291,8 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - if ( - config.n_shared_experts is None - or is_rocm_aiter_fusion_shared_expert_enabled() - ): + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled: self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -330,14 +325,14 @@ def __init__( # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 - if not is_rocm_aiter_moe_enabled() + if not self.is_rocm_aiter_moe_enabled else self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, n_shared_experts=config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() else None, ) @@ -371,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - if not is_rocm_aiter_moe_enabled(): + if not self.is_rocm_aiter_moe_enabled: final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None @@ -1428,6 +1423,9 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -1456,7 +1454,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: num_experts=self.config.n_routed_experts + ( self.config.n_shared_experts - if is_rocm_aiter_fusion_shared_expert_enabled() + if rocm_aiter_moe_shared_expert_enabled else 0 ), num_redundant_experts=self.num_redundant_experts, @@ -1472,9 +1470,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if spec_layer is not None: continue # skip spec decode layers for main model - is_fuse_shared_experts_layer = ( - is_rocm_aiter_fusion_shared_expert_enabled() - and ("mlp.shared_experts" in name) + is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( + "mlp.shared_experts" in name ) for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1abd6300036d..e6536a02a73d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: + from vllm._aiter_ops import rocm_aiter_ops + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and 
envs.VLLM_ROCM_USE_AITER) + and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None ) @@ -202,12 +204,15 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: from importlib.util import find_spec + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + if rocm_aiter_ops.is_mha_enabled(): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. return _Backend.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: @@ -228,19 +233,23 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: + from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") - if use_mla: - from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( - is_aiter_mla_enabled, + + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) + if use_mla: if selected_backend is None: selected_backend = ( _Backend.ROCM_AITER_MLA - if is_aiter_mla_enabled() or block_size == 1 + if rocm_aiter_ops.is_mla_enabled() or block_size == 1 else _Backend.TRITON_MLA ) @@ -265,12 +274,12 @@ def get_attn_backend_cls( logger.info("Using FlexAttention backend.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + rocm_aiter_ops.is_mha_enabled() ) or selected_backend == _Backend.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend.") return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" if ( - envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + rocm_aiter_ops.is_triton_unified_attn_enabled() ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend.") return ( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 40ce12c4bd75..e38f7bcfa44e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -198,6 +198,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionLayer, @@ -270,28 +271,15 @@ class QueryLenSupport(Enum): flashinfer_available = False -def is_rocm_aiter_fp8bmm_enabled() -> bool: - return ( - current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER_FP8BMM - and envs.VLLM_ROCM_USE_AITER - ) - - -if is_rocm_aiter_fp8bmm_enabled(): - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501 - ) - - def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn - ): - DTYPE_MAX = torch.finfo(dtype).max - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) - scale = DTYPE_MAX / amax - x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) - return x_scl_sat.to(dtype).contiguous(), 
scale.float().reciprocal() +def dynamic_per_batched_tensor_quant( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +): + DTYPE_MAX = torch.finfo(dtype).max + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) + scale = DTYPE_MAX / amax + x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() logger = init_logger(__name__) @@ -1109,6 +1097,7 @@ def __init__( self.kv_b_proj = kv_b_proj self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): @@ -1158,7 +1147,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1187,7 +1176,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1196,7 +1185,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1208,10 +1197,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm( + x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) # Convert from (B, N, V) to (B, N * V) @@ -1571,7 +1559,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1600,7 +1588,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_K.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True ) @@ -1609,7 +1597,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): dtype=torch.bfloat16, device=self.W_V.device, ) - aiter_triton_fp8_bmm( + rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True ) else: @@ -1958,7 +1946,6 @@ def forward( # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) - # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) @@ -1966,9 +1953,9 @@ def forward( decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if is_rocm_aiter_fp8bmm_enabled(): + if self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, 
B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm( + decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, self.W_K, self.W_K_scale, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 4ad7236eb1be..5757aeadba05 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,9 +6,8 @@ import torch -import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionLayer -from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import ( @@ -22,10 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA - - class AiterMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -284,7 +279,7 @@ def _forward_decode( # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd( + rocm_aiter_ops.mla_decode_fwd( q, kv_buffer, o, From d0e186c16f0d62af8c128e2dc7c94cde1387ac02 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 11 Nov 2025 00:30:06 +0800 Subject: [PATCH 020/183] [V0 Deprecation] Remove unused `context_len` and `seq_len` from M-RoPE (#28395) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/ernie45_vl.py | 3 --- vllm/model_executor/models/glm4_1v.py | 3 --- vllm/model_executor/models/glm4v.py | 3 --- vllm/model_executor/models/interfaces.py | 4 ---- vllm/model_executor/models/keye.py | 3 --- vllm/model_executor/models/keye_vl1_5.py | 3 --- vllm/model_executor/models/paddleocr_vl.py | 3 --- vllm/model_executor/models/qwen2_5_omni_thinker.py | 3 --- vllm/model_executor/models/qwen2_5_vl.py | 3 --- vllm/model_executor/models/qwen2_vl.py | 3 --- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 2 -- vllm/model_executor/models/qwen3_vl.py | 4 +--- vllm/model_executor/models/transformers/multimodal.py | 4 +--- 13 files changed, 2 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 7c1eba103ae7..f287cff12086 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1435,8 +1435,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1569,7 +1567,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 121e84469c52..b9cd3545ec45 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -1622,8 +1622,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, 
second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1754,7 +1752,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 2de1e4810952..ebf6934dddea 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -625,8 +625,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -758,7 +756,6 @@ def get_mrope_input_positions( llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index b634c7ec7d67..d6a8f86d998b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -995,8 +995,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1012,8 +1010,6 @@ def get_mrope_input_positions( image_grid_thw: Image grid dimensions (t, h, w) video_grid_thw: Video grid dimensions (t, h, w) second_per_grid_ts: Seconds per grid timestep for videos - context_len: Context length - seq_len: Sequence length audio_feature_lengths: Audio feature lengths for multimodal models use_audio_in_video: Whether to use audio in video for interleaving diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 5f8659a3064e..42f16ad9f3b3 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1630,8 +1630,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1759,6 +1757,5 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 
13e5b2d5f157..6f95a59d36d2 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -600,8 +600,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -729,6 +727,5 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 377b41a35578..631475c964c0 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -1179,8 +1179,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1293,7 +1291,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 7e970ebbe2bb..fac281d2caf4 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -927,8 +927,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1125,7 +1123,6 @@ def get_mrope_input_positions( mrope_position_delta = ( torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) ) - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index d337f1606943..48834ba699e4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1118,8 +1118,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1232,7 +1230,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py 
index 9206ac8f9d03..b3999e6c934e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1240,8 +1240,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -1360,7 +1358,6 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index f20e67902721..da489a812f55 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -1417,8 +1417,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 2d8f431bb8fa..fe0124ef3258 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1419,8 +1419,6 @@ def get_mrope_input_positions( hf_config: PretrainedConfig, image_grid_thw: list[list[int]] | torch.Tensor, video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, second_per_grid_ts: list[float] | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, @@ -1519,7 +1517,7 @@ def get_mrope_input_positions( llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta def get_language_model(self) -> torch.nn.Module: diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 10abd8659536..476074542e6a 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -371,8 +371,6 @@ def get_mrope_input_positions( image_grid_thw: list[list[int]] | torch.Tensor | None, video_grid_thw: list[list[int]] | torch.Tensor | None, second_per_grid_ts: list[float] | None = None, - context_len: int = 0, - seq_len: int | None = None, audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: @@ -390,7 +388,7 @@ def get_mrope_input_positions( video_grid_thw=video_grid_thw, ) - mrope_positions = mrope_positions[:, 0, context_len:seq_len] + mrope_positions = mrope_positions[:, 0] mrope_position_delta = mrope_position_delta[0].item() return mrope_positions, mrope_position_delta From b039bfda8f72b442d42dbdac40f51572bf045ad1 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 10 Nov 2025 12:21:52 -0500 Subject: [PATCH 021/183] [Bugfix] Fix 
persistent_masked_m_silu_mul_quant tests (#28366) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- csrc/quantization/activation_kernels.cu | 15 ++++++++++----- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 5 ++++- .../layers/fused_moe/batched_deep_gemm_moe.py | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 6fcd246f63c5..2521b2797e2c 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant( // This kernel currently only supports H % 128 == 0 and assumes a // fixed GROUP_SIZE of 128. + static constexpr int GROUP_SIZE = 128; + TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || y_q.dtype() == torch::kFloat8_e4m3fnuz); TORCH_CHECK(y_s.dtype() == torch::kFloat32); - TORCH_CHECK(input.size(-1) % 256 == 0); + TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); using Idx_t = int64_t; @@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant( Idx_t stride_counts_e = tokens_per_expert.stride(0); - static constexpr int GROUP_SIZE = 128; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ @@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant( static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + int const NUM_GROUPS = H / GROUP_SIZE; if (!use_ue8m0) { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); } } else { - if (H >= 4096) { + if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { + /* 8 warps config */ static constexpr int NUM_STAGES = 4; static constexpr int THREAD_COUNT = 256; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); } else { + /* 1 warp config */ static constexpr int THREAD_COUNT = 32; KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); } diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 97a55c37b9a3..420dbbffaac0 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -25,6 +25,7 @@ (8, 16, 128 * 2, fp8_dtype), (8, 16, 128 * 3, fp8_dtype), (8, 64, 7168, fp8_dtype), + (8, 128, 128 * 33, fp8_dtype), (8, 128, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype), @@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): ) # Run the SiLU V2 kernel + # TODO (varun): use_e8m0 is set to false as the reference impl does + # not handle that case. 
y_q, y_s = persistent_masked_m_silu_mul_quant( - y, tokens_per_expert, group_size=group_size + y, tokens_per_expert, group_size=group_size, use_ue8m0=False ) torch.cuda.synchronize() diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 095ec966ea7e..b8a97e92ab79 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant( tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert num_parallel_tokens=16, group_size: int = 128, + use_ue8m0: bool | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). The first half of the last dimension is @@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant( device=y.device, ) - use_ue8m0 = is_deep_gemm_e8m0_used() + use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used() cuda_arch = current_platform.get_device_capability( device_id=y.device.index From 34553b9d2702dd2a27a578fec819e88e76dcbfb4 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 10 Nov 2025 09:34:57 -0800 Subject: [PATCH 022/183] [Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492) Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../model_executor/layers/fused_moe/config.py | 21 +++++++++++++++ .../layers/fused_moe/flashinfer_trtllm_moe.py | 26 +++++++++---------- vllm/model_executor/layers/fused_moe/layer.py | 20 ++++++++++++++ .../model_executor/layers/quantization/fp8.py | 14 +++++----- .../quantization/utils/flashinfer_utils.py | 23 +++++++++------- vllm/model_executor/models/qwen3_moe.py | 2 ++ vllm/model_executor/models/qwen3_next.py | 2 ++ 7 files changed, 78 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cbc3caafcf2f..a7bd64b1c65e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from enum import IntEnum from typing import Optional, Union import torch @@ -91,6 +92,26 @@ def _quant_flags_to_group_shape( return a_shape, w_shape +# The type of method in top-K routing +# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups + # -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # RenormalizeNaive: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # TopK: TopK (no softmax) + TopK = (5,) + # Unspecified + Unspecified = 6.0 + + @dataclass class FusedMoEQuantDesc: """ diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index f21fe16c5108..51e06ac54f49 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8( w2_weight_scale_inv: torch.Tensor, global_num_experts: int, top_k: int, - num_expert_group: int, - topk_group: int, + num_expert_group: int | None, + topk_group: int | None, intermediate_size: int, expert_offset: int, local_num_experts: int, block_shape: list[int], - routed_scaling: float = 1.0, + routing_method_type: int = RoutingMethodType.DeepSeekV3, + routed_scaling: float | None = 1.0, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + topk_group = topk_group if topk_group is not None else 0 assert top_k <= global_num_experts - assert top_k <= 8 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 + assert top_k <= 10 assert global_num_experts % 4 == 0 - assert top_k < (topk_group * global_num_experts / num_expert_group) assert block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 256 - assert global_num_experts <= 256 + # Routing kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) # NOTE: scales of hidden states have to be transposed! @@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8( local_expert_offset=expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim( - x.shape[0], top_k, global_num_experts - ), - routing_method_type=2, # DeepSeek-styled routing method + tile_tokens_dim=None, + routing_method_type=routing_method_type, use_shuffled_weight=False, ) @@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( expert_offset: int, local_num_experts: int, block_shape: list[int], + routing_method_type: int, routed_scaling: float = 1.0, ) -> torch.Tensor: return torch.empty_like(x) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 45b0f50a7997..f86a93e30003 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -31,6 +31,7 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, biased_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton @@ -1213,6 +1214,7 @@ def __init__( zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, + routing_method_type: int | None = None, ): super().__init__() @@ -1397,6 +1399,24 @@ def __init__( "Only softmax scoring function is supported for non-grouped topk." 
) + # ToDo: Better logic to determine the routing method type + if routing_method_type is not None: + self.routing_method_type = routing_method_type + else: + if scoring_func == "sigmoid": + if self.use_grouped_topk: + self.routing_method_type = RoutingMethodType.DeepSeekV3 + elif self.top_k == 1: + self.routing_method_type = RoutingMethodType.Llama4 + elif self.scoring_func == "softmax": + self.routing_method_type = ( + RoutingMethodType.Renormalize + if not self.renormalize + else RoutingMethodType.RenormalizeNaive + ) + else: + self.routing_method_type = RoutingMethodType.TopK + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e4e1cbff712f..f5fc750baaea 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -28,6 +28,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + RoutingMethodType, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe @@ -1222,22 +1223,20 @@ def apply( assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) - assert scoring_func == "sigmoid", ( - f"Expected 'sigmoid' scoring func but got {scoring_func}" - ) + if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - assert ( - renormalize and use_grouped_topk and custom_routing_function is None - ) e_score_correction_bias = ( e_score_correction_bias.to(x.dtype) if e_score_correction_bias is not None else None ) + routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), + routing_logits=router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits, routing_bias=e_score_correction_bias, x=x, w13_weight=layer.w13_weight, @@ -1252,6 +1251,7 @@ def apply( expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, + routing_method_type=routing_method_type, routed_scaling=routed_scaling_factor, ) else: diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 50ea049c3d5a..e49d374f154d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): + from flashinfer import next_positive_power_of_2 + # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. tile_tokens_dim = 8 - # from flashinfer import next_positive_power_of_2 - - # # Guess tokens per expert assuming perfect expert distribution first. - # num_tokens_per_expert = (num_tokens * top_k) // num_experts - # # And pad the number to the next power of 2. - # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # # Cap to 8-64 tokens per CTA tile as it's the range supported by the - # # kernel. 
- # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + # A factor considering tokens are not perfectly balanced among experts. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-max_tile_tokens_dim tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) return tile_tokens_dim diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index a7e6772bb708..d57b82cb0227 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -43,6 +43,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -171,6 +172,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) self.gate = ReplicatedLinear( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 55bbad7a8b27..aa7de5aa5f29 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -34,6 +34,7 @@ fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) @@ -173,6 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: From 6d54336ae550528408e0c84cffb7857c426509f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:53:32 +0100 Subject: [PATCH 023/183] [Bugfix] Fix llguidance backend, rollback when EOS was encountered (#25905) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rémi Delacourt Signed-off-by: remi Co-authored-by: Russell Bryant --- .../test_backend_guidance.py | 118 ++++++++++++++++++ vllm/v1/structured_output/backend_guidance.py | 10 +- 2 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/v1/structured_output/test_backend_guidance.py diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py new file mode 100644 index 000000000000..771076186a3b --- /dev/null +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import AutoTokenizer + +from vllm.config import StructuredOutputsConfig, 
VllmConfig +from vllm.config.model import ModelConfig +from vllm.config.speculative import SpeculativeConfig +from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.structured_output.backend_guidance import GuidanceBackend +from vllm.v1.structured_output.backend_types import StructuredOutputOptions + +TOKENIZER = "gpt2" + + +def test_backend_guidance_rollback_terminated(): + # Test that the backend guidance successfully rollbacks from a + # terminated state. This can happen with speculative decoding, + # where the draft model proposes EOS and it is verified by the + # guidance backend. In that case we are in a stopped state, but + # it should be reverted in case EOS is not accepted by the target + # model. + vllm_config = VllmConfig( + decoding_config=StructuredOutputsConfig( + backend="guidance", + ) + ) + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + + backend = GuidanceBackend( + vllm_config, + tokenizer=tokenizer, + vocab_size=50257, + ) + + grammar = backend.compile_grammar( + StructuredOutputOptions.JSON, '{"type": "object"}' + ) + + prompt = tokenizer.encode('{"a": "b"}') + assert len(prompt) > 1 + dummy_wrong = tokenizer.encode('{"a"}') + for token in prompt: + assert grammar.accept_tokens("", [token]) + assert not grammar.is_terminated() + assert grammar.accept_tokens("", [tokenizer.eos_token_id]) + assert grammar.is_terminated() + # Giving any other token should also be accepted + assert grammar.accept_tokens("", dummy_wrong) + # Rollback is done from where state was terminated, so from '}' not EOS + grammar.rollback(len(prompt) - 1) + assert not grammar.is_terminated() + assert grammar.validate_tokens([tokenizer.eos_token_id]) == [] + assert grammar.validate_tokens(dummy_wrong) != dummy_wrong + assert grammar.accept_tokens("", prompt[1:]) + assert not grammar.is_terminated() + assert grammar.accept_tokens("", [tokenizer.eos_token_id]) + assert grammar.is_terminated() + # Rollback of <= 0 should not change the terminated state + grammar.rollback(0) + assert grammar.is_terminated() + grammar.rollback(-1) + assert grammar.is_terminated() + + +def test_grammar_bitmask_with_specdec(): + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + prompt = tokenizer.encode('{"a": "b"}') + vllm_config = VllmConfig( + model_config=ModelConfig(tokenizer=TOKENIZER), + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3), + ) + structured_output_manager = StructuredOutputManager(vllm_config) + + for i in range(1, 2): + sampling_params = SamplingParams( + structured_outputs=StructuredOutputsParams( + json='{"type": "object"}', + ), + ) + sampling_params.structured_outputs._backend = "guidance" + + my_req_id = f"my_req_id_{i}" + request = Request( + my_req_id, + prompt_token_ids=prompt[:i], + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=tokenizer.eos_token_id, + ) + + structured_output_manager.grammar_init(request) + + def grammar_bitmask(req: Request, tokens: list[int]) -> None: + structured_output_manager.grammar_bitmask( + requests={req.request_id: req}, + structured_output_request_ids={req.request_id: 0}, + scheduled_spec_decode_tokens={req.request_id: tokens}, + ) + # At this point, we rolled-back, so should not be terminated + assert not req.structured_output_request.grammar.is_terminated() + + # The grammar might not yet be 
compiled, so we wait for it + while not request.structured_output_request._check_grammar_completion(): + continue + + assert request.structured_output_request.grammar.accept_tokens( + request.request_id, prompt[:i] + ) + + grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) + grammar_bitmask( + request, prompt[i:] + [tokenizer.eos_token_id] + prompt + ) # EOS not the final token + grammar_bitmask(request, prompt[i:]) # EOS not present + grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 00a625e103bd..2962a439dcb3 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar): vocab_size: int printed_error: bool = False terminated: bool = False + rollback_lag: int = 0 def check_error(self): if not self.printed_error: @@ -127,6 +128,8 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """ if self.ll_tokenizer.eos_token in tokens: + if self.ll_matcher.is_stopped() and not self.terminated: + self.rollback_lag = 1 self.terminated = True if self.ll_matcher.is_stopped(): @@ -163,8 +166,11 @@ def validate_tokens(self, tokens: list[int]) -> list[int]: return tokens[:num_tokens] def rollback(self, num_tokens: int) -> None: - self.ll_matcher.rollback(num_tokens) - self.check_error() + if num_tokens > 0: + self.ll_matcher.rollback(num_tokens - self.rollback_lag) + self.terminated = False + self.rollback_lag = 0 + self.check_error() def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: # this will automatically return [EOS] mask if the matcher is stopped From 9c84ca8293034cdf8a324f7bec3a60101e0e12c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20M=2E=20K=C3=BCbler?= <44084297+jmkuebler@users.noreply.github.com> Date: Mon, 10 Nov 2025 21:06:04 +0100 Subject: [PATCH 024/183] [FA/Chore] Bump FA version for FP8 two-level accumulation (#27889) Signed-off-by: Jonas Kuebler Co-authored-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 931090db50e9..29db9fa273a4 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54 + GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 40d33264c680a8c725b93db6ccce608f99e5c7da Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 10 Nov 2025 12:39:19 -0800 Subject: [PATCH 025/183] [Bugfix][EPLB] Disabled shared expert overlap when EPLB is enabled (#28377) Signed-off-by: Sage Moore Signed-off-by: Sage Moore Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- .../layers/fused_moe/shared_fused_moe.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py 
b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 6b4a0b8cf073..3d0c5636d6c0 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -28,13 +28,18 @@ def __init__( super().__init__(**kwargs) self._shared_experts = shared_experts - # Disable shared expert overlap if we are not using - # flashinfer + DP since there is nothing to be gained in this case. - # Disabling the overlap optimization also prevents the shared experts - # from being hidden from torch.compile. + # Disable shared expert overlap if we are using eplb, because of + # correctness issues, or if using flashinfer with DP, since there + # is nothing to be gained in this case. Disabling the overlap + # optimization also prevents the shared experts from being hidden + # from torch.compile. self.use_overlapped = ( use_overlapped - and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + and not ( + # TODO(wentao): find the root cause and remove this condition + self.enable_eplb + or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + ) and self._shared_experts is not None ) From bf6a3d0ff5a69e0a30567f2ad417530c002eaa4e Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Mon, 10 Nov 2025 13:03:21 -0800 Subject: [PATCH 026/183] [Misc] Add more scoping for improved trace (#28329) Signed-off-by: Wei Wei --- vllm/v1/core/sched/scheduler.py | 116 ++++++++++++++-------------- vllm/v1/engine/core.py | 117 ++++++++++++++++++----------- vllm/v1/engine/llm_engine.py | 37 +++++---- vllm/v1/worker/gpu_model_runner.py | 70 +++++++++-------- 4 files changed, 192 insertions(+), 148 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c17b19b58c97..46dc1071b839 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -38,6 +38,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext logger = init_logger(__name__) @@ -259,49 +260,52 @@ def schedule(self) -> SchedulerOutput: continue # Schedule newly needed KV blocks for the request. - while True: - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens, - ) - - if new_blocks is not None: - # The request can be scheduled. - break - - # The request cannot be scheduled. - # Preempt the lowest-priority request. 
- if self.policy == SchedulingPolicy.PRIORITY: - preempted_req = max( - self.running, - key=lambda r: (r.priority, r.arrival_time), + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, ) - self.running.remove(preempted_req) - if preempted_req in scheduled_running_reqs: - scheduled_running_reqs.remove(preempted_req) - token_budget += num_scheduled_tokens[preempted_req.request_id] - req_to_new_blocks.pop(preempted_req.request_id) - num_scheduled_tokens.pop(preempted_req.request_id) - req_index -= 1 - else: - preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - preempted_req.num_preemptions += 1 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp - ) + if new_blocks is not None: + # The request can be scheduled. + break - self.waiting.prepend_request(preempted_req) - preempted_reqs.append(preempted_req) - if preempted_req == request: - # No more request to preempt. Cannot schedule this request. - break + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + req_index -= 1 + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break if new_blocks is None: # Cannot schedule this request. @@ -599,13 +603,14 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) - if self.running: - any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id + ) ) - ) # Construct the scheduler output. 
new_reqs_data = [ @@ -614,13 +619,14 @@ def schedule(self) -> SchedulerOutput: ) for req in scheduled_new_reqs ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, - scheduled_resumed_reqs, - num_scheduled_tokens, - scheduled_spec_decode_tokens, - req_to_new_blocks, - ) + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) # Record the request ids that were scheduled in this step. self.prev_step_scheduled_req_ids.clear() @@ -649,8 +655,8 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta - - self._update_after_schedule(scheduler_output) + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) return scheduler_output def _update_after_schedule( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fba018432e0a..c3efd52130cc 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -61,6 +61,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -315,17 +316,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): return {}, False - scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, non_block=True) - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) - with self.log_error_detail(scheduler_output): - model_output = future.result() - if model_output is None: - model_output = self.model_executor.sample_tokens(grammar_output) - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step: schedule"): + scheduler_output = self.scheduler.schedule() + + with record_function_or_nullcontext("core step: execute_model"): + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) + + with record_function_or_nullcontext("core step: update_from_output"): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 @@ -363,32 +368,49 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): - scheduler_output = self.scheduler.schedule() - exec_future = self.model_executor.execute_model( - scheduler_output, non_block=True - ) + with record_function_or_nullcontext("core step_with_batch_queue: schedule"): + scheduler_output = self.scheduler.schedule() + with record_function_or_nullcontext( + "core step_with_batch_queue: execute_model" + ): + exec_future = self.model_executor.execute_model( + scheduler_output, 
non_block=True + ) model_executed = scheduler_output.total_num_scheduled_tokens > 0 if scheduler_output.pending_structured_output_tokens: - # We need to defer sampling until we have processed the model output - # from the prior step. - deferred_scheduler_output = scheduler_output - # Block-wait for execute to return (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - assert exec_result is None + with record_function_or_nullcontext( + "core step_with_batch_queue: pending_structured_output_tokens" + ): + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + # Block-wait for execute to return + # (continues running async on the GPU). + with self.log_error_detail(scheduler_output): + exec_result = exec_future.result() + assert exec_result is None else: - # We aren't waiting for any tokens, get any grammar output immediately. - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with record_function_or_nullcontext( + "core step_with_batch_queue: get_grammar_bitmask" + ): + # We aren't waiting for any tokens, get any grammar + # output immediately. + grammar_output = self.scheduler.get_grammar_bitmask( + scheduler_output + ) # Block-wait for execute to return (continues running async on the GPU). with self.log_error_detail(scheduler_output): exec_result = exec_future.result() if exec_result is None: - # Call sample tokens. - future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) + with record_function_or_nullcontext( + "core step_with_batch_queue: sample_tokens" + ): + # Call sample tokens. + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) else: # No sampling required (e.g. all requests finished). future = cast(Future[ModelRunnerOutput], exec_future) @@ -408,27 +430,34 @@ def step_with_batch_queue( # only be called when the scheduler contains requests or the queue # is non-empty. return None, False - - # Block until the next result is available. - future, scheduler_output = batch_queue.pop() - with self.log_error_detail(scheduler_output): - model_output = future.result() - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step_with_batch_queue: model_output"): + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + with self.log_error_detail(scheduler_output): + model_output = future.result() + with record_function_or_nullcontext( + "core step_with_batch_queue: update_from_output" + ): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. if deferred_scheduler_output: - # We now have the tokens needed to compute the bitmask for the - # deferred request. Get the bitmask and call sample tokens. 
- grammar_output = self.scheduler.get_grammar_bitmask( - deferred_scheduler_output - ) - future = self.model_executor.sample_tokens(grammar_output, non_block=True) - batch_queue.appendleft((future, deferred_scheduler_output)) + with record_function_or_nullcontext( + "core step_with_batch_queue: deferred_scheduler_output" + ): + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. + grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) + batch_queue.appendleft((future, deferred_scheduler_output)) return engine_core_outputs, model_executed diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e32c74aff313..d27d13840989 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -36,6 +36,7 @@ from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.stats import IterationStats +from vllm.v1.utils import record_function_or_nullcontext from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -280,28 +281,32 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]: return [] # 1) Get EngineCoreOutput from the EngineCore. - outputs = self.engine_core.get_output() + with record_function_or_nullcontext("llm_genine step: get_output"): + outputs = self.engine_core.get_output() # 2) Process EngineCoreOutputs. - iteration_stats = IterationStats() if self.log_stats else None - processed_outputs = self.output_processor.process_outputs( - outputs.outputs, - engine_core_timestamp=outputs.timestamp, - iteration_stats=iteration_stats, - ) - self.output_processor.update_scheduler_stats(outputs.scheduler_stats) + with record_function_or_nullcontext("llm_genine step: process_outputs"): + iteration_stats = IterationStats() if self.log_stats else None + processed_outputs = self.output_processor.process_outputs( + outputs.outputs, + engine_core_timestamp=outputs.timestamp, + iteration_stats=iteration_stats, + ) + self.output_processor.update_scheduler_stats(outputs.scheduler_stats) # 3) Abort any reqs that finished due to stop strings. - self.engine_core.abort_requests(processed_outputs.reqs_to_abort) + with record_function_or_nullcontext("llm_genine step: abort_requests"): + self.engine_core.abort_requests(processed_outputs.reqs_to_abort) # 4) Record stats - if self.logger_manager is not None and outputs.scheduler_stats is not None: - self.logger_manager.record( - scheduler_stats=outputs.scheduler_stats, - iteration_stats=iteration_stats, - mm_cache_stats=self.processor.stat_mm_cache(), - ) - self.do_log_stats_with_interval() + with record_function_or_nullcontext("llm_genine step: record_stats"): + if self.logger_manager is not None and outputs.scheduler_stats is not None: + self.logger_manager.record( + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + mm_cache_stats=self.processor.stat_mm_cache(), + ) + self.do_log_stats_with_interval() return processed_outputs.request_outputs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 26007d29d61b..9403b5756e05 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2525,7 +2525,7 @@ def execute_model( "after execute_model() returns None." 
) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - with record_function_or_nullcontext("Preprocess"): + with record_function_or_nullcontext("gpu_model_runner: preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output) @@ -2648,7 +2648,7 @@ def execute_model( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ), - record_function_or_nullcontext("Forward"), + record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): model_output = self._model_forward( @@ -2659,7 +2659,7 @@ def execute_model( **model_kwargs, ) - with record_function_or_nullcontext("Postprocess"): + with record_function_or_nullcontext("gpu_model_runner: postprocess"): if self.use_aux_hidden_state_outputs: # True when EAGLE 3 is used. hidden_states, aux_hidden_states = model_output @@ -2756,12 +2756,12 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) - with record_function_or_nullcontext("Sample"): + with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): + with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, sampled_token_ids, @@ -2799,7 +2799,7 @@ def propose_draft_token_ids(sampled_token_ids): # as inputs, and does not need to wait for bookkeeping to finish. propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2826,37 +2826,41 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("EPLB"): + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() - - output = ModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) if not self.use_async_scheduling: return output - - async_output = AsyncGPUModelRunnerOutput( - model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, - logprobs_tensors=sampler_output.logprobs_tensors, - invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, - ) - - # Save ref of sampled_token_ids CPU tensor if the batch contains - # any requests with sampling params that that require output ids. 
- self.input_batch.set_async_sampled_token_ids( - async_output.sampled_token_ids_cpu, - async_output.async_copy_ready_event, - ) + with record_function_or_nullcontext( + "gpu_model_runner: AsyncGPUModelRunnerOutput" + ): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + with record_function_or_nullcontext( + "gpu_model_runner: set_async_sampled_token_ids" + ): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) return async_output From 6dec9f61098786690b4ca2140682dbafb849f8d9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 10 Nov 2025 17:01:17 -0500 Subject: [PATCH 027/183] [BugFix] Fix DeepGEMM over-allocating workspace (#28254) Signed-off-by: Lucas Wilkinson --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 484b8aa9d107..86cdd25f2c87 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -215,7 +215,7 @@ def workspace_shapes( ) assert M_sum % block_m == 0 - workspace1 = (M_sum, max(N, K)) + workspace1 = (M_sum, N) workspace2 = (M_sum, max(N // 2, K)) output = (M, K) return (workspace1, workspace2, output) From 4b94ed8f928533b1f7c3a0293790ccb592b49f1a Mon Sep 17 00:00:00 2001 From: Andrew Xia Date: Mon, 10 Nov 2025 14:07:49 -0800 Subject: [PATCH 028/183] [Frontend][2/n] remove empty content from _parse_tool_calls_from_content (#28331) Signed-off-by: Andrew Xia Co-authored-by: Andrew Xia --- vllm/entrypoints/openai/serving_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8ce4ff574699..30b8499b08d5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1375,6 +1375,8 @@ def _parse_tool_calls_from_content( for tool_call in tool_call_info.tool_calls ) content = tool_call_info.content + if content and content.strip() == "": + content = None else: # No tool calls. 
return None, content From 30700b1cd7de51f191be718215a58f5a8ddcb8aa Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 10 Nov 2025 17:36:11 -0500 Subject: [PATCH 029/183] [CI] Fix Plugin Tests Tests (#28413) Signed-off-by: Robert Shaw --- vllm/config/vllm.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d4ee6f980e6e..0fca967d9083 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -608,17 +608,19 @@ def __post_init__(self): ) current_platform.check_and_update_config(self) - assert ( - self.parallel_config.dcp_kv_cache_interleave_size - <= self.cache_config.block_size - and self.cache_config.block_size - % self.parallel_config.dcp_kv_cache_interleave_size - == 0 - ), ( - f"Block_size({self.cache_config.block_size}) should be " - "greater than or equal to and divisible by dcp_kv_cache_interleave_size " - f"({self.parallel_config.dcp_kv_cache_interleave_size})." - ) + # If DCP, ensure the block size is right. + if self.parallel_config.decode_context_parallel_size > 1: + assert ( + self.parallel_config.dcp_kv_cache_interleave_size + <= self.cache_config.block_size + and self.cache_config.block_size + % self.parallel_config.dcp_kv_cache_interleave_size + == 0 + ), ( + f"Block_size({self.cache_config.block_size}) should be greater " + "than or equal to and divisible by dcp_kv_cache_interleave_size " + f"({self.parallel_config.dcp_kv_cache_interleave_size})." + ) assert ( self.parallel_config.dcp_kv_cache_interleave_size == 1 From 021143561fcffa9bee133d0b3bd311bc5cb3703c Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:13:36 -1000 Subject: [PATCH 030/183] [ROCm] Add missing gemm_a8w8_blockscale import (#28378) Signed-off-by: Yong Hoon Shin --- .../layers/quantization/utils/fp8_utils.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 63726c07b7d1..c63196b89357 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -316,38 +316,39 @@ def _run_aiter( assert self.act_quant_group_shape == GroupShape(1, 128) n, k = weight.shape - if input_scale is not None: - q_input = input_2d - # MI350 case uses triton kernel - if ( + use_triton = ( not current_platform.is_fp8_fnuz() and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k) - ): + ) + + if use_triton: + gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale + else: + gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale + + if input_scale is not None: + q_input = input_2d + # MI350 case uses triton kernel + elif use_triton: q_input, input_scale = per_token_group_quant_fp8( input_2d, self.act_quant_group_shape.col, column_major_scales=False, use_ue8m0=False, ) - return rocm_aiter_ops.triton_gemm_a8w8_blockscale( - q_input, - weight, - input_scale, - weight_scale, - input_2d.dtype, - ) - # MI300 uses tuned AITER ASM/C++ kernel else: q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d) - return rocm_aiter_ops.gemm_w8a8_blockscale( - q_input, - weight, - input_scale, - weight_scale, - input_2d.dtype, - ) + + return gemm_a8w8_blockscale_op( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + output_dtype=input_2d.dtype, 
+ ) def _run_triton( self, From d17ecc6b19b597615893be6c0eb61c9b4a9c9455 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Tue, 11 Nov 2025 00:33:11 +0100 Subject: [PATCH 031/183] [PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds (#24248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič Signed-off-by: Luka Govedič Signed-off-by: ilmarkov Co-authored-by: Luka Govedič Co-authored-by: Luka Govedič Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 4 +- .../kernels/benchmark_fused_collective.py | 1129 +++++++++++++++++ tests/compile/test_fusions_e2e.py | 7 + vllm/compilation/collective_fusion.py | 132 +- vllm/config/compilation.py | 50 +- vllm/model_executor/layers/fused_moe/layer.py | 45 +- 6 files changed, 1284 insertions(+), 83 deletions(-) create mode 100644 benchmarks/kernels/benchmark_fused_collective.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b81c090fa471..3152cd6488f3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -463,8 +463,8 @@ steps: - pytest -v -s compile/test_multimodal_compile.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 22min - timeout_in_minutes: 35 +- label: PyTorch Fullgraph Test # 27min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 000000000000..38e7fdcf5542 --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. 
Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time + +import pandas as pd +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", world_size) + return None, None + + try: + # Create IPC workspace + ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=world_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=get_tp_group().device_group, + use_fp32_lamport=use_fp32_lamport, + ) + ) + + _FI_WORKSPACE_TENSOR = workspace_tensor + return ipc_handles, workspace_tensor + except Exception as e: + logger.error("Failed to setup FlashInfer workspace: %s", e) + return None, None + + +def cleanup_flashinfer_workspace(ipc_handles): + """Cleanup FlashInfer workspace.""" + if flashinfer_comm is None or ipc_handles is None: + return + + try: + group = get_tp_group().device_group + flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group) + except Exception as e: + logger.error("Failed to cleanup FlashInfer workspace: %s", e) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def 
__init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + } + + +def flashinfer_fused_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + allreduce_params: "FlashInferFusedAllReduceParams", + use_oneshot: bool, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm operation.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=None, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + use_oneshot: bool = True, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: torch.Tensor | None = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if 
flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +class VllmFusedAllreduce: + def __init__(self, hidden_dim, dtype): + self.rms_eps = 1e-6 + self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype) + self.fp8_quant = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + ) + + def allreduce_rmsnorm( + self, input_tensor: torch.Tensor, residual: torch.Tensor | None + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + return self.rms_norm(allreduce_out, residual) + + def allreduce_rmsnorm_fp8_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + scale_factor: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out + else: + rms_out, residual_out = rms_out + quant_out = self.fp8_quant(rms_out, scale_factor) + return quant_out, residual_out + + def allreduce_rmsnorm_fp4_quant( + self, + input_tensor: torch.Tensor, + residual: torch.Tensor | None, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + ): + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + rms_out = self.rms_norm(allreduce_out, residual) + if residual is None: + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, output_scale + else: + rms_out, residual_out = rms_out + SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) + return quant_out, residual_out, output_scale + + +def create_test_tensors( + num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, 
trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_residual: bool, + allreduce_params: FlashInferFusedAllReduceParams | None, + quant_modes: set[str], + no_oneshot: bool, +): + """Run all benchmarks for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype) + use_oneshot_options = [False] if no_oneshot else [True, False] + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + + if "none" in quant_modes: + # Standard AllReduce + RMSNorm + for custom_op in ["-rms_norm", "+rms_norm"]: + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op])) + ): + try: + suffix = ( + "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm" + ) + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm, + input_tensor, + residual=residual, + ) + results[f"standard_allreduce_{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results[f"standard_allreduce_{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=use_oneshot, + ) + 
results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e) + results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float( + "inf" + ) + + if "fp8" in quant_modes: + # Standard AllReduce + RMSNorm + FP8 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]: + suffix += ( + "_custom_quant_fp8" + if "+" in quant_fp8_custom_op + else "_native_quant_fp8" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op, quant_fp8_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results[f"standard_allreduce{suffix}"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results[f"standard_allreduce{suffix}"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=["-rms_norm", "-quant_fp8"] + ) + ) + ): + try: + standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + scale_factor=scale_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( + float("inf") + ) + + if "fp4" in quant_modes and current_platform.has_device_capability(100): + # Standard AllReduce + RMSNorm + FP4 Quant + for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]: + suffix = ( + "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" + ) + with set_current_vllm_config( + VllmConfig( + compilation_config=CompilationConfig( + custom_ops=[rms_norm_custom_op] + ) + ) + ): + try: + time_ms = benchmark_operation( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + 
results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + with set_current_vllm_config( + VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) + ): + try: + standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile( + vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, + fullgraph=True, + dynamic=False, + ) + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + quant_out=fp4_quant_out, + input_global_scale=scale_fp4, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = ( + time_ms + ) + except Exception as e: + logger.error( + "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + for use_oneshot in use_oneshot_options: + suffix = "_oneshot" if use_oneshot else "_twoshot" + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=use_oneshot, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( + float("inf") + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and 
results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results( + results_dict, + num_tokens, + hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, +): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print( + f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_modes={','.join(sorted(list(quant_modes)))}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + lines: list[str] = [] + lines.append("# FlashInfer Fused Collective Operations Benchmark Results") + lines.append("") + lines.append(f"**World Size:** {world_size} ") + lines.append(f"**Hidden Dimension:** {args.hidden_dim} ") + lines.append(f"**Warmup Iterations:** {args.warmup} ") + lines.append(f"**Benchmark Trials:** {args.trials} ") + modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A" + lines.append(f"**Quantization Modes:** {modes} ") + lines.append("") + lines.append("---") + lines.append("") + + for entry in all_results: + num_tokens = entry["num_tokens"] + dtype = entry["dtype"] 
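+        # NOTE: each entry in all_results was recorded in main() for one
+        # benchmark configuration; it is rendered below as its own markdown
+        # section with a per-operation results table.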
+ use_residual = entry["use_residual"] + results_dict = entry["results"] + input_size_mb = entry["input_size_mb"] + residual_str = "with residual" if use_residual else "no residual" + + lines.append( + f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}" + ) + lines.append(f"**Input Size:** {input_size_mb:.2f} MB") + lines.append("") + + prepared = prepare_results_with_speedups(results_dict) + # Build DataFrame for markdown export + rows = [ + { + "Operation": r["operation"].replace("_", " ").title(), + "Time (ms)": r["time_str"], + "Speedup": r["speedup_str"], + } + for r in prepared + ] + df = pd.DataFrame(rows) + if df.empty: + lines.append("No results.") + else: + lines.append(df.to_markdown(index=False)) + lines.append("") + + return "\n".join(lines) + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "a") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--num-tokens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Numbers of tokens to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual connection tests", + ) + + parser.add_argument( + "--quant-modes", + type=str, + default="none,fp8,fp4", + help=( + "Comma-separated quantization modes to run: none, fp8, fp4. " + "Default: none,fp8,fp4" + ), + ) + + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + parser.add_argument( + "--no-oneshot", + action="store_true", + help="Skip oneshot benchmarks", + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." 
+ ) + + # Parse quantization modes + valid_quant_modes = {"none", "fp8", "fp4"} + raw_modes = [ + m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip() + ] + quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"} + invalid = sorted(list(quant_modes - valid_quant_modes)) + if invalid: + raise ValueError( + f"Invalid --quant-modes entries: {','.join(invalid)}. " + f"Valid options are: {','.join(sorted(valid_quant_modes))}." + ) + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes)))) + if flashinfer_comm is not None: + logger.info( + "FlashInfer available - will benchmark fused operations", + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + + configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for num_tokens, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s", + num_tokens, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + num_tokens, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_modes=quant_modes, + no_oneshot=args.no_oneshot, + ) + + # Store results for markdown export + if rank == 0: + # Calculate input size in MB + input_size_mb = ( + num_tokens * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) + all_results.append( + { + "num_tokens": num_tokens, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_modes": sorted(list(quant_modes)), + "input_size_mb": input_size_mb, + "results": results, + } + ) + + print_results( + results, + num_tokens, + args.hidden_dim, + dtype, + use_residual, + quant_modes, + input_size_mb, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 58026e7e7e78..4b910bc28579 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -71,6 +71,13 @@ class ModelBackendTestCase(NamedTuple): attention_fusions=0, allreduce_fusions=65, ), + ModelBackendTestCase( + model_name="Qwen/Qwen3-30B-A3B", + model_kwargs=dict(max_model_len=1024), + 
backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=97, + ), ] elif current_platform.is_rocm(): diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7294ddce64ba..69d4606d73eb 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -9,7 +9,6 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -450,34 +449,41 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) +# Max size of the input tensor per world size per device capability +# to use flashinfer fused allreduce +FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = { + 90: { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 0.5, # 0.5MB + }, + 100: { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, +} + +# Max size of the input tensor per world size per device capability +# to use flashinfer one shot fused allreduce +# OneShot max size is at most 64MB / world size (FlashInfer restriction) +_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = { + 90: { + 2: 32, # 32MB + 4: 2, # 2MB + 8: 0.5, # 0.5MB + }, + 100: { + 2: 32, # 32MB + 4: 4, # 4MB + 8: 1, # 1MB + }, +} + + if flashinfer_comm is not None: _FI_WORKSPACE_TENSOR = None - MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB - } - - try: - _FI_MAX_SIZES.update( - { - int(k): int(float(v) * MiB) - for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - } - ) - except Exception as e: - raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) - ) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, @@ -491,7 +497,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -500,12 +505,20 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, - ) - if use_flashinfer: + + if num_tokens <= max_token_num: + device_capability = current_platform.get_device_capability().to_int() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) + # Use one shot if no max size for one shot is specified + use_oneshot = ( + max_one_shot_size_mb is None + or current_tensor_size <= max_one_shot_size_mb * MiB + ) + assert _FI_WORKSPACE_TENSOR is not None, ( "Flashinfer must be enabled when using flashinfer" ) @@ -532,7 +545,7 @@ def call_trtllm_fused_allreduce_norm( 
hidden_dim=allreduce_in.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, launch_with_pdl=launch_with_pdl, - use_oneshot=True, + use_oneshot=use_oneshot, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, pattern_code=pattern_code, @@ -545,7 +558,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None and fuse_rms_quant: + if scale_factor is not None and scale_out is None: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -568,15 +581,10 @@ def call_trtllm_fused_allreduce_norm( norm_out = allreduce_out else: torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None: - if scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - else: - torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor - ) + if scale_factor is not None and scale_out is not None: + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, scale_out, scale_factor + ) if scale_factor is None or norm_out is not None: # we need to return allreduce output # in cases of non quant fused AR + RMS norm @@ -595,7 +603,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None, @@ -629,7 +636,6 @@ def __init__( world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -637,9 +643,7 @@ def __init__( self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True - self.use_oneshot = False self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -649,7 +653,6 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } @@ -1119,23 +1122,35 @@ def __init__(self, config: VllmConfig): "skipping allreduce fusion pass" ) return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.pass_config.flashinfer_max_size( + self.tp_size + ) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) - // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + element_size = 4 if use_fp32_lamport else 2 + self.max_token_num = max_size // (self.hidden_dim * element_size) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens + ) + logger.debug_once( + f"Flashinfer max size: {max_size // (1024 * 1024)} MB," + "Maximal number of tokens used by " + f"Flashinfer Allreduce Fusion: {self.max_token_num}", + scope="global", ) + self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( 
tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1148,10 +1163,7 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + max_token_num=self.max_token_num, ) self.register_patterns() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c84a060922e3..92cf16f259fe 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -111,11 +111,52 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: float | None = None + """The threshold of the communicated tensor sizes under which + vllm should use flashinfer fused allreduce. Specified as a + float in MB. + Unspecified will fallback to default values + which are compute capability and world size dependent. + FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { + 90: { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + 100: { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + }, where key is the device capability""" # TODO(luka) better pass enabling system. + def flashinfer_max_size(self, world_size: int) -> int | None: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Returns None if world size + is not supported by configs as it's not supported by flashinfer. + """ + + MiB = 1024 * 1024 + max_size_mb = self.fi_allreduce_fusion_max_size_mb + if max_size_mb is None: + max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) + + return int(max_size_mb * MiB) if max_size_mb is not None else None + + @staticmethod + def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]: + from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + return {} + return FI_ALLREDUCE_FUSION_MAX_SIZE_MB.get( + current_platform.get_device_capability().to_int(), {} + ) + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -136,6 +177,11 @@ def __post_init__(self) -> None: "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work" ) + if self.enable_fi_allreduce_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. 
" + "Allreduce + rms norm + quant (fp8) fusion might not work" + ) @config diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f86a93e30003..27ad9c8fd1c2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2356,6 +2356,16 @@ def forward_native( value=0.0, ) + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if ( + not self.is_sequence_parallel + and not self.use_dp_chunking + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) + return states + if self.shared_experts is None: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we @@ -2366,7 +2376,14 @@ def forward_native( fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name ) - return fused_output[..., :og_hidden_states] + if self.zero_expert_num is not None and self.zero_expert_num > 0: + assert isinstance(fused_output, tuple) + fused_output, zero_expert_result = fused_output + return (reduce_output(fused_output) + zero_expert_result)[ + ..., :og_hidden_states + ] + else: + return reduce_output(fused_output)[..., :og_hidden_states] else: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we @@ -2379,8 +2396,8 @@ def forward_native( hidden_states, router_logits, self.layer_name ) return ( - shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states], + reduce_output(shared_output)[..., :og_hidden_states], + reduce_output(fused_output)[..., :og_hidden_states], ) def forward_cuda( @@ -2667,31 +2684,21 @@ def forward_impl( assert isinstance(final_hidden_states, tuple) final_hidden_states, zero_expert_result = final_hidden_states - def reduce_output( - states: torch.Tensor, do_combine: bool = True - ) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: + def combine_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: states = get_ep_group().combine(states, self.is_sequence_parallel) - - if ( - not self.is_sequence_parallel - and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1) - ): - states = self.maybe_all_reduce_tensor_model_parallel(states) - return states if self.shared_experts is not None: return ( - reduce_output(final_hidden_states[0], do_combine=False), - reduce_output(final_hidden_states[1]), + final_hidden_states[0], + combine_output(final_hidden_states[1]), ) elif self.zero_expert_num is not None and self.zero_expert_num > 0: assert isinstance(final_hidden_states, torch.Tensor) - return reduce_output(final_hidden_states) + zero_expert_result + return (combine_output(final_hidden_states), zero_expert_result) else: - return reduce_output(final_hidden_states) + return combine_output(final_hidden_states) @classmethod def make_expert_params_mapping( From b30372cbd045aeac50833cd6fe6084d2edd5252c Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Mon, 10 Nov 2025 15:34:18 -0800 Subject: [PATCH 032/183] [Perf] Move gc.freeze logic from EngineCoreProc to EngineCore for better coverage (#27896) Signed-off-by: Jialin Ouyang --- vllm/benchmarks/serve.py | 5 ++--- vllm/distributed/parallel_state.py | 3 +++ vllm/entrypoints/openai/api_server.py | 6 ++---- vllm/utils/gc_utils.py | 15 +++++++++++++++ vllm/v1/engine/core.py | 15 ++++++++------- 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/vllm/benchmarks/serve.py 
b/vllm/benchmarks/serve.py index e58cf5911282..0e9b0fbe2c02 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -19,7 +19,6 @@ import argparse import asyncio import contextlib -import gc import importlib.util import json import os @@ -49,6 +48,7 @@ from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.gc_utils import freeze_gc_heap MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1414,8 +1414,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: percentile_metrics: str = args.percentile_metrics or default_percentile_metrics # Avoid GC processing "static" data - reduce pause times. - gc.collect() - gc.freeze() + freeze_gc_heap() benchmark_result = await benchmark( task_type=task_type, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a9b01e82562b..c78e6a32733c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1483,6 +1483,9 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + # Ensure all objects are not freezed before cleanup + gc.unfreeze() + destroy_model_parallel() destroy_distributed_environment() if shutdown_ray: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c8c8d5c034d5..51191879e478 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import asyncio -import gc import hashlib import importlib import inspect @@ -118,6 +116,7 @@ from vllm.tasks import POOLING_TASKS from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.v1.engine.exceptions import EngineDeadError @@ -153,8 +152,7 @@ async def _force_log(): # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() + freeze_gc_heap() try: yield finally: diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 4dd85ef26f34..160ac9ac263a 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -89,6 +89,21 @@ def handle(self, phase: str, info: dict[str, int]) -> None: ) +def freeze_gc_heap() -> None: + """ + Freeze all objects tracked by the garbage collector. It should be invoked + after server init / warmup, to reduce GC overhead from static objects + during serving time. + """ + # Ensure all static objects are pushed down to the oldest generation for + # freeze + gc.collect(0) + gc.collect(1) + gc.collect(2) + # Freeze all GC tracked objects + gc.freeze() + + def maybe_attach_gc_debug_callback() -> None: """ Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. 
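For context, freeze_gc_heap() above only wraps standard-library gc calls, and the matching gc.unfreeze() added to cleanup_dist_env_and_memory() undoes it. A minimal, self-contained sketch of that freeze/unfreeze lifecycle (plain CPython gc, illustrative data only, no vLLM APIs) is:

import gc

def freeze_startup_heap() -> None:
    # Same idea as freeze_gc_heap() above: sweep the young generations so
    # surviving startup objects reach the oldest generation, then freeze them
    # so later collections no longer traverse them.
    gc.collect(0)
    gc.collect(1)
    gc.collect(2)
    gc.freeze()

if __name__ == "__main__":
    startup_state = [list(range(1_000)) for _ in range(100)]  # stand-in for static objects
    freeze_startup_heap()
    print("frozen objects:", gc.get_freeze_count())
    # ... steady-state serving would run here; frozen objects are skipped by GC ...
    gc.unfreeze()  # mirrors the cleanup path, so teardown can reclaim everything
    print("frozen after unfreeze:", gc.get_freeze_count())
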
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c3efd52130cc..ffb5232e770d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc import os import queue import signal @@ -27,7 +26,10 @@ from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.utils.gc_utils import maybe_attach_gc_debug_callback +from vllm.utils.gc_utils import ( + freeze_gc_heap, + maybe_attach_gc_debug_callback, +) from vllm.utils.hashing import get_hash_fn_by_name from vllm.utils.network_utils import make_zmq_socket from vllm.utils.system_utils import decorate_logs, set_process_title @@ -197,6 +199,10 @@ def __init__( self.step if self.batch_queue is None else self.step_with_batch_queue ) + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + freeze_gc_heap() + def _initialize_kv_caches( self, vllm_config: VllmConfig ) -> tuple[int, int, KVCacheConfig]: @@ -651,11 +657,6 @@ def __init__( assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - # Mark the startup heap as static so that it's ignored by GC. - # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() - # If enable, attach GC debugger after static variable freeze. maybe_attach_gc_debug_callback() From a5a790eea6035760c71eae1861c1e5f369bc6d08 Mon Sep 17 00:00:00 2001 From: Adrian Abeyta Date: Mon, 10 Nov 2025 17:42:37 -0600 Subject: [PATCH 033/183] [Bugfix] Ensure calculated KV scales are applied in attention. 
(#27232) Signed-off-by: adabeyta --- .buildkite/test-pipeline.yaml | 7 +++++-- tests/compile/test_full_graph.py | 10 ++++++++-- vllm/attention/layer.py | 29 +++++++---------------------- vllm/v1/worker/gpu_model_runner.py | 19 +++++++++---------- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3152cd6488f3..a0d2076199b1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -471,8 +471,8 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph.py - # Limit to no custom ops to reduce running time + - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time # Wrap with quotes to escape yaml and avoid starting -k string with a - - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" @@ -951,10 +951,13 @@ steps: - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py - tests/compile/test_fusions_e2e.py + - tests/compile/test_full_graph.py commands: - nvidia-smi # Run all e2e fusion tests - pytest -v -s tests/compile/test_fusions_e2e.py + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 0ad8c17d8668..71f90f6d8d3e 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -183,8 +183,14 @@ def test_custom_compile_config( "compilation_mode", [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) -def test_fp8_kv_scale_compile(compilation_mode: int): - model = "Qwen/Qwen2-0.5B" +@pytest.mark.parametrize( + "model", + [ + "Qwen/Qwen2-0.5B", # Standard attention model + "deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model + ], +) +def test_fp8_kv_scale_compile(compilation_mode: int, model: str): model_kwargs = { "quantization": "fp8", "kv_cache_dtype": "fp8_e4m3", diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17e025155a43..96272981692c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -745,6 +745,9 @@ def forward( k_pe: torch.Tensor, output_shape: torch.Size | None = None, ) -> torch.Tensor: + if self.calculate_kv_scales: + torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) + if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -752,12 +755,6 @@ def forward( attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # Mirror Attention.forward scale calculation path - if self.calculate_kv_scales and getattr( - attn_metadata, "enable_kv_scales_calculation", False - ): - self.calc_kv_scales(q, kv_c_normed, k_pe) - if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) self.impl.forward( @@ -786,14 +783,6 @@ def forward( ) return output else: - # We can still access forward context to check calculation flag - if self.calculate_kv_scales: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - if getattr(attn_metadata, "enable_kv_scales_calculation", False): - 
self.calc_kv_scales(q, kv_c_normed, k_pe) return torch.ops.vllm.unified_mla_attention( q, kv_c_normed, @@ -881,17 +870,13 @@ def maybe_calc_kv_scales( layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] - if attn_metadata is None or not getattr( - attn_metadata, "enable_kv_scales_calculation", False - ): + # Only calculate if the layer's calculate_kv_scales flag is True + # This flag gets set to False after the first forward pass + if not self.calculate_kv_scales: return - self = forward_context.no_compile_layers[layer_name] self.calc_kv_scales(query, key, value) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9403b5756e05..6fccf2ea2f47 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -279,6 +279,9 @@ def __init__( # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + + # Always set to false after the first forward pass + self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens @@ -2625,16 +2628,12 @@ def execute_model( ) # Set cudagraph mode to none if calc_kv_scales is true. - if attn_metadata is not None: - metadata_list = ( - attn_metadata.values() - if isinstance(attn_metadata, dict) - else [attn_metadata] - ) - if any( - getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list - ): - cudagraph_runtime_mode = CUDAGraphMode.NONE + # KV scales calculation involves dynamic operations that are incompatible + # with CUDA graph capture. + if self.calculate_kv_scales: + cudagraph_runtime_mode = CUDAGraphMode.NONE + # Mark KV scales as calculated after the first forward pass + self.calculate_kv_scales = False # Run the model. # Use persistent buffers for CUDA graphs. 
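For clarity, the runner change above follows a compute-once-then-disable flag pattern: the first forward pass runs without CUDA graphs so the dynamic KV-scale calculation can happen, and every later pass reuses the scales under graph capture. A rough, self-contained sketch of that control flow (illustrative class, not the vLLM runner) is:

class ScaleOnceRunner:
    """Toy stand-in for the model runner's calculate_kv_scales handling."""

    def __init__(self, calculate_kv_scales: bool) -> None:
        self.calculate_kv_scales = calculate_kv_scales

    def execute_model(self) -> str:
        use_cuda_graphs = True
        if self.calculate_kv_scales:
            # Dynamic scale calculation is incompatible with graph capture,
            # so this pass runs eagerly and the flag is cleared afterwards.
            use_cuda_graphs = False
            self.calculate_kv_scales = False
        return "eager (scales computed)" if not use_cuda_graphs else "cuda-graph (scales reused)"

runner = ScaleOnceRunner(calculate_kv_scales=True)
print(runner.execute_model())  # eager (scales computed)
print(runner.execute_model())  # cuda-graph (scales reused)
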
From 0bf29fadf5f8b28817fbccb037fb70adaef3f7f1 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 10 Nov 2025 17:57:41 -0600 Subject: [PATCH 034/183] [Test] Remove old non-varlen FA2 test (#28420) Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_flash_attn.py | 119 --------------------- 1 file changed, 119 deletions(-) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index 18995545552e..6e5468969bf2 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -9,7 +9,6 @@ from vllm.vllm_flash_attn import ( fa_version_unsupported_reason, flash_attn_varlen_func, - flash_attn_with_kvcache, is_fa_version_supported, ) @@ -83,124 +82,6 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.parametrize("use_out", [True, False]) -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", SOFT_CAPS) -@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) -@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) -@pytest.mark.parametrize("fa_version", [2, 3]) -@pytest.mark.parametrize("q_dtype", QDTYPES) -@torch.inference_mode() -def test_flash_attn_with_paged_kv( - use_out: bool, - kv_lens: list[int], - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: float | None, - num_blocks: int, - sliding_window: int | None, - fa_version: int, - q_dtype: torch.dtype | None, -) -> None: - torch.set_default_device("cuda") - if not is_fa_version_supported(fa_version): - pytest.skip( - f"Flash attention version {fa_version} not supported due " - f'to: "{fa_version_unsupported_reason(fa_version)}"' - ) - if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip( - "Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type" - ) - - current_platform.seed_everything(0) - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn( - num_blocks, block_size, num_kv_heads, head_size, dtype=dtype - ) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint( - 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 - ) - - q = query.unsqueeze(1) - out = torch.empty_like(q) if use_out else None - - maybe_quantized_query = q - maybe_quantized_key_cache = key_cache - maybe_quantized_value_cache = value_cache - q_descale = None - k_descale = None - v_descale = None - if q_dtype is not None: - # QKV are drawn from N(0, 1): no need for a fp8 scaling factor - maybe_quantized_query = q.to(q_dtype) - maybe_quantized_key_cache = key_cache.to(q_dtype) - maybe_quantized_value_cache = value_cache.to(q_dtype) - - scale_shape = (num_seqs, num_kv_heads) - q_descale = torch.ones(scale_shape, dtype=torch.float32) - k_descale = torch.ones(scale_shape, dtype=torch.float32) - v_descale 
= torch.ones(scale_shape, dtype=torch.float32) - - output = flash_attn_with_kvcache( - q=maybe_quantized_query, - k_cache=maybe_quantized_key_cache, - v_cache=maybe_quantized_value_cache, - out=out, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - window_size=window_size, - fa_version=fa_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - output = output if not use_out else out - output = output.squeeze(1) - - atol, rtol = 1.5e-2, 1e-2 - if q_dtype is not None: - atol, rtol = 1.5e-1, 1.5e-1 - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window, - ) - ( - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), - f"{torch.max(torch.abs(output - ref_output))}", - ) - - @pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize( "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] From 35d801f13fa5bd79ae74707388b1fa4e1caf9ba5 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:08:40 -0500 Subject: [PATCH 035/183] [Feature] Refactor batch invariant fp8 DeepGEMM (#27606) Signed-off-by: yewentao256 --- .../model_executor/layers/quantization/fp8.py | 98 +++---------------- 1 file changed, 11 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f5fc750baaea..c7d5b251cf4e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -43,7 +43,6 @@ QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -95,11 +94,9 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( - fp8_gemm_nt, get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear, ) from vllm.utils.flashinfer import has_flashinfer_moe from vllm.utils.import_utils import has_deep_gemm @@ -554,83 +551,19 @@ def apply( # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. 
if vllm_is_batch_invariant(): - # Call is_deep_gemm_supported() ahead of time for torch.compile - # dynamo has trouble tracing through - if self.block_quant and should_use_deepgemm_for_fp8_linear( - torch.bfloat16, layer.weight, self.use_deep_gemm - ): - # use group quant consistent with block size across K - assert self.act_q_group_shape is not None - q_input, input_scale = QuantFP8( - False, - self.act_q_group_shape, - column_major_scales=True, - )(x) - - output_2d = torch.empty( - (q_input.shape[0], layer.weight.shape[0]), - dtype=torch.bfloat16, - device=q_input.device, - ) - fp8_gemm_nt( - (q_input, input_scale), - (layer.weight, layer.weight_scale), - output_2d, - ) - if bias is not None: - output_2d = output_2d + bias - return output_2d - - # Dequantize FP8 weights to BF16 - weight_fp8 = layer.weight.to(torch.bfloat16) - weight_scale = layer.weight_scale.to(torch.bfloat16) - - # Handle different quantization granularities if self.block_quant: - # Block-wise quantization: - # - Weight is NOT transposed, shape is [N, K] (output_size, input_size) - # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!) assert self.weight_block_size is not None - block_n, block_k = self.weight_block_size # Note: order is [N, K] - - N, K = weight_fp8.shape - - # determine expected number of blocks along N and K - num_blocks_n = (N + block_n - 1) // block_n - num_blocks_k = (K + block_k - 1) // block_k - - # scale layout may be [num_blocks_n, num_blocks_k] - # or [num_blocks_k, num_blocks_n] depending on backend - if weight_scale.dim() != 2: - raise RuntimeError( - f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}" - ) - - scale_rows, scale_cols = weight_scale.shape - if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): - if num_blocks_n == num_blocks_k: - # ambiguous square case, warn and skip transpose - logger.warning( - "Batch-invariant FP8: square block-scale %dx%d; " - "skipping transpose to avoid misorientation.", - scale_rows, - scale_cols, - ) - else: - # clear KN -> transpose to NK - weight_scale = weight_scale.t() - - # Expand scale to match weight dimensions - # scale_expanded should have shape [N, K] - scale_expanded = weight_scale.repeat_interleave( - block_n, dim=0 - ).repeat_interleave(block_k, dim=1) - # Trim to exact weight size (in case of padding) - scale_expanded = scale_expanded[:N, :K] - weight_bf16 = weight_fp8 * scale_expanded + return self.w8a8_block_fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) else: - # Per-tensor quantization: weight IS transposed to [K, N] - # scale should be scalar or [1] or per-output-channel [N] + # per-tensor/channel: dequant to BF16 and run GEMM + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) if weight_scale.numel() == 1: # Per-tensor: simple scalar multiplication weight_bf16 = weight_fp8 * weight_scale @@ -649,16 +582,7 @@ def apply( else: # Fallback weight_bf16 = weight_fp8 * weight_scale - - # For block quant, weight is [N, K], for per-tensor it's [K, N] - # F.linear expects weight to be [N, K], so: - if self.block_quant: - # Already in correct shape [N, K] - output = torch.nn.functional.linear(x, weight_bf16, bias) - else: - # Need to transpose back: [K, N] -> [N, K] - output = torch.nn.functional.linear(x, weight_bf16.t(), bias) - return output + return torch.nn.functional.linear(x, weight_bf16.t(), bias) if self.use_marlin: return apply_fp8_marlin_linear( From 
39029d519276fddbe0c36440e0eefcdda069b969 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 10 Nov 2025 20:36:29 -0500 Subject: [PATCH 036/183] [CI/Test Fix] Fix CP tests on Blackwell (#28404) Signed-off-by: Lucas Wilkinson Signed-off-by: Lucas Wilkinson Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/distributed/test_context_parallel.py | 12 ++++++++++++ vllm/attention/ops/common.py | 1 - 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 7f8e77a75621..3576efca591c 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -14,6 +14,7 @@ from typing import Literal, NamedTuple import pytest +import torch from vllm.config.model import RunnerOption from vllm.logger import init_logger @@ -254,6 +255,17 @@ def test_cp_generation( test_options: CPTestOptions, num_gpus_available, ): + if ( + model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat" + and torch.cuda.get_device_capability() < (9, 0) + ): + pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher") + if ( + model_id == "bigcode/gpt_bigcode-santacoder" + and torch.cuda.get_device_capability() != (9, 0) + ): + pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0") + _compare_cp_with_tp( model_id, parallel_setup, diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 75fdcb8f48b2..2cbb5c91cc3b 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -195,7 +195,6 @@ def cp_lse_ag_out_rs( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) - assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) if return_lse: From de540c0354b9ecfa979c917a4599f8030d4105be Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 10 Nov 2025 21:29:48 -0500 Subject: [PATCH 037/183] [Feature] Add env var `VLLM_MOE_USE_DEEP_GEMM` (#28422) Signed-off-by: yewentao256 --- vllm/envs.py | 6 ++++++ .../compressed_tensors/compressed_tensors_moe.py | 10 +++++++++- vllm/model_executor/layers/quantization/fp8.py | 2 +- vllm/model_executor/warmup/deep_gemm_warmup.py | 3 +++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 30c62e90e9fb..9421488051e5 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -147,6 +147,7 @@ VLLM_TPU_MOST_MODEL_LEN: int | None = None VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = True + VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", @@ -1116,6 +1117,10 @@ def get_vllm_port() -> int | None: ), # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), + # Allow use of DeepGemm specifically for MoE fused ops (overrides only MoE). + "VLLM_MOE_USE_DEEP_GEMM": lambda: bool( + int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1")) + ), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. 
"VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) @@ -1569,6 +1574,7 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_MOE_USE_DEEP_GEMM", "VLLM_USE_DEEP_GEMM_E8M0", "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP16", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d32ae6674ee6..59567f2ca13c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -966,10 +966,18 @@ def select_gemm_impl( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, + allow_deep_gemm=( + envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + ), ) else: logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True) + return TritonOrDeepGemmExperts( + self.moe_quant_config, + allow_deep_gemm=( + envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + ), + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c7d5b251cf4e..83d136600b77 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -158,7 +158,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: return Fp8MoeBackend.MARLIN # deepGEMM on supported platforms with block-quantized weights - if envs.VLLM_USE_DEEP_GEMM and block_quant: + if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant: if not has_deep_gemm(): logger.warning_once("DeepGEMM backend requested but not available.") elif is_deep_gemm_supported(): diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index bdcebd498ef0..e0c584df8760 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -148,6 +148,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: + if not (envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM): + return False + if not isinstance(module, FusedMoE): return False From f2d9ad0620d9aa71481527dcfafdb8357da00470 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 10 Nov 2025 19:53:24 -0700 Subject: [PATCH 038/183] Only register rocm_aiter_ops if aiter is found (#28428) Signed-off-by: mgoin --- vllm/_aiter_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 9a4b5f3399be..8d35aa65738b 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -938,4 +938,5 @@ def shuffle_weights( return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) -rocm_aiter_ops.register_ops_once() +if IS_AITER_FOUND: + rocm_aiter_ops.register_ops_once() From 57201a6a4c53bbd6adb9a4b702c95d5f480161d5 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Mon, 10 Nov 2025 18:57:12 -0800 Subject: [PATCH 039/183] Fix rotary embedding benchmark script (#28323) Signed-off-by: Xin Yang --- 
benchmarks/kernels/benchmark_rope.py | 154 +++++++++++---------------- 1 file changed, 64 insertions(+), 90 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 29ef6409bb16..074b7a440b61 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,97 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate +import itertools -import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope -from vllm.platforms import current_platform +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser +batch_size_range = [2**i for i in range(0, 8, 2)] +seq_len_range = [2**i for i in range(6, 10, 1)] +num_heads_range = [32, 48] +configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range)) -def benchmark_rope_kernels_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: int | None, - dtype: torch.dtype, - seed: int, - device: str, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - # silulating serving 4 LoRAs - scaling_factors = [1, 2, 4, 8] - # batched RoPE can take multiple scaling factors - batched_rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": tuple(scaling_factors)}, + +def get_benchmark(head_size, rotary_dim, is_neox_style, device): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "num_heads"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "flashinfer", "vllm"], + line_names=["PyTorch", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}", + args={}, + ) ) - # non-batched RoPE takes only one scaling factor, we create multiple - # instances to simulate the same behavior - non_batched_ropes: list[RotaryEmbedding] = [] - for scaling_factor in scaling_factors: - non_batched_ropes.append( - get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": (scaling_factor,)}, - ) + def benchmark(batch_size, seq_len, num_heads, provider): + dtype = torch.bfloat16 + max_position = 8192 + base = 10000 + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope = rope.to(dtype=dtype, device=device) + cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) + + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + query = torch.randn( + (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device ) + key = torch.randn_like(query) - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) + quantiles = [0.5, 0.2, 0.8] - # create query offsets for batched RoPE, we concat multiple kv cache - # together and each query needs to find the right kv cache of its type - offset_map = torch.tensor( - list( - accumulate( - [0] - + [ - max_position * scaling_factor * 
2 - for scaling_factor in scaling_factors[:-1] - ] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_native(positions, query.clone(), key.clone()), + quantiles=quantiles, ) - ) - ) - query_types = torch.randint( - 0, len(scaling_factors), (batch_size, seq_len), device=device - ) - # map query types to offsets - query_offsets = offset_map[query_types] - # the kernel takes flattened offsets - flatten_offsets = query_offsets.flatten() + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox_style, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_cuda(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms - # batched queries of the same type together for non-batched RoPE - queries = [query[query_types == i] for i in range(len(scaling_factors))] - keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qkr = zip(queries, keys, non_batched_ropes) - # synchronize before start timing - torch.cuda.synchronize() - with nvtx.annotate("non-batched", color="yellow"): - for q, k, r in packed_qkr: - r.forward(positions, q, k) - torch.cuda.synchronize() - with nvtx.annotate("batched", color="green"): - batched_rope.forward(positions, query, key, flatten_offsets) - torch.cuda.synchronize() + return benchmark if __name__ == "__main__": @@ -116,17 +95,12 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument( "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" ) + parser.add_argument("--save-path", type=str, default="./configs/rope/") args = parser.parse_args() - print(args) - benchmark_rope_kernels_multi_lora( - is_neox_style=args.is_neox_style, - batch_size=args.batch_size, - seq_len=args.seq_len, - num_heads=args.num_heads, - head_size=args.head_size, - rotary_dim=args.rotary_dim, - dtype=getattr(torch, args.dtype), - seed=args.seed, - device=args.device, + # Get the benchmark function + benchmark = get_benchmark( + args.head_size, args.rotary_dim, args.is_neox_style, args.device ) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) From 8d706cca903a008169e7ac8f1dc1f65c8ffd85c0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 10 Nov 2025 19:41:23 -0800 Subject: [PATCH 040/183] [Misc] FlattenLogprobs -> FlatLogprobs (#28335) --- tests/samplers/test_logprobs.py | 16 +++++-------- tests/test_logprobs.py | 40 ++++++++++++++++----------------- vllm/envs.py | 8 +++---- vllm/logprobs.py | 26 ++++++++++----------- 4 files changed, 43 insertions(+), 47 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 87f5d40ac1da..c9d227599cde 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -4,7 +4,7 @@ import pytest from vllm import SamplingParams -from vllm.logprobs import FlattenLogprobs +from vllm.logprobs import FlatLogprobs MODELS = ["distilbert/distilgpt2"] MAX_TOKENS = 5 @@ -16,17 +16,17 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("greedy", [True, False]) -@pytest.mark.parametrize("flatten_logprobs", [True, False]) +@pytest.mark.parametrize("flat_logprobs", [True, False]) def test_ranks( vllm_runner, model, dtype, greedy, - flatten_logprobs, + 
flat_logprobs, example_prompts, monkeypatch: pytest.MonkeyPatch, ): - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0") with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts] @@ -44,12 +44,8 @@ def test_ranks( decode_tokens, _, decode_logprobs, prompt_logprobs = result # Ensure the return type of logprobs is accurate - assert isinstance( - prompt_logprobs, FlattenLogprobs if flatten_logprobs else list - ) - assert isinstance( - decode_logprobs, FlattenLogprobs if flatten_logprobs else list - ) + assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list) + assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list) ######################## # Check prompt logprobs diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py index 1799d3638178..d26a460d2bca 100644 --- a/tests/test_logprobs.py +++ b/tests/test_logprobs.py @@ -5,7 +5,7 @@ import pytest from vllm.logprobs import ( - FlattenLogprobs, + FlatLogprobs, Logprob, LogprobsOnePosition, append_logprobs_for_next_position, @@ -14,8 +14,8 @@ ) -def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") +def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0") prompt_logprobs = create_prompt_logprobs() assert isinstance(prompt_logprobs, list) @@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert len(sample_logprobs) == 0 -def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") +def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1") prompt_logprobs = create_prompt_logprobs() - assert isinstance(prompt_logprobs, FlattenLogprobs) + assert isinstance(prompt_logprobs, FlatLogprobs) assert prompt_logprobs.start_indices == [0] assert prompt_logprobs.end_indices == [0] assert len(prompt_logprobs.token_ids) == 0 @@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert prompt_logprobs[0] == dict() sample_logprobs = create_sample_logprobs() - assert isinstance(sample_logprobs, FlattenLogprobs) + assert isinstance(sample_logprobs, FlatLogprobs) assert len(sample_logprobs.start_indices) == 0 assert len(sample_logprobs.end_indices) == 0 assert len(sample_logprobs.token_ids) == 0 @@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: assert len(sample_logprobs) == 0 -def test_append_logprobs_for_next_position_none_flatten( +def test_append_logprobs_for_next_position_none_flat( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0") logprobs = create_sample_logprobs() append_logprobs_for_next_position( logprobs, @@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten( ] -def test_append_logprobs_for_next_position_flatten( +def test_append_logprobs_for_next_position_flat( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") + monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1") logprobs = create_sample_logprobs() 
append_logprobs_for_next_position( logprobs, @@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten( rank=11, num_logprobs=-1, ) - assert isinstance(logprobs, FlattenLogprobs) + assert isinstance(logprobs, FlatLogprobs) assert logprobs.start_indices == [0, 1] assert logprobs.end_indices == [1, 3] assert logprobs.token_ids == [1, 2, 3] @@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten( } -def test_flatten_logprobs_append() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_append() -> None: + logprobs = FlatLogprobs() logprobs.append(LOGPROBS_ONE_POSITION_0) logprobs.append(LOGPROBS_ONE_POSITION_1) assert logprobs.start_indices == [0, 1] @@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None: assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"] -def test_flatten_logprobs_extend() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_extend() -> None: + logprobs = FlatLogprobs() # Extend with list[LogprobsOnePosition] logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]) assert logprobs.start_indices == [0, 3] @@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None: assert logprobs.ranks == [40, 50, 60, 10] assert logprobs.decoded_tokens == ["40", "50", "60", "10"] - other_logprobs = FlattenLogprobs() + other_logprobs = FlatLogprobs() other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0]) - # Extend with another FlattenLogprobs + # Extend with another FlatLogprobs logprobs.extend(other_logprobs) assert logprobs.start_indices == [0, 3, 4, 6] assert logprobs.end_indices == [3, 4, 6, 7] @@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None: assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"] -def test_flatten_logprobs_access() -> None: - logprobs = FlattenLogprobs() +def test_flat_logprobs_access() -> None: + logprobs = FlatLogprobs() logprobs.extend( [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0] ) diff --git a/vllm/envs.py b/vllm/envs.py index 9421488051e5..52178e5f5250 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -223,7 +223,7 @@ VLLM_GC_DEBUG: str = "" VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" - VLLM_FLATTEN_LOGPROBS: bool = False + VLLM_FLAT_LOGPROBS: bool = False def get_default_cache_root(): @@ -1481,11 +1481,11 @@ def get_vllm_port() -> int | None: "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] ), - # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than + # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than # the original list[dict[int, Logprob]] approach. # After enabled, PromptLogprobs and SampleLogprobs would populated as - # FlattenLogprobs. - "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))), + # FlatLogprobs. + "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/logprobs.py b/vllm/logprobs.py index bf66e5f75c79..a34398db2c96 100644 --- a/vllm/logprobs.py +++ b/vllm/logprobs.py @@ -30,16 +30,16 @@ class Logprob: @dataclass -class FlattenLogprobs(MutableSequence[LogprobsOnePosition]): +class FlatLogprobs(MutableSequence[LogprobsOnePosition]): """ - Flatten logprobs of a request into multiple primitive type lists. + Flat logprobs of a request into multiple primitive type lists. 
Compared to list[dict[int, Logprob]], this data structure reduced GC overhead significantly. As it flattened logprob information for all positions and ranks in to multiple primitive type lists (i.e. logprobs, token_ids, ranks per token_ids, decoded_tokens). So regardless of the sequence length and top_logprobs setup, - FlattenLogprobs would only introduce a constant amount of objects. + FlatLogprobs would only introduce a constant amount of objects. As each position might contains different amount of ranks, start_indices_per_position would be used to access the logprob ranges @@ -107,7 +107,7 @@ def __len__(self) -> int: def __getitem__(self, position: int) -> LogprobsOnePosition: ... @overload - def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ... + def __getitem__(self, s: slice, /) -> "FlatLogprobs": ... def __getitem__(self, index: int | slice): """Extracts logprobs of a given position or slice""" @@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice): elif isinstance(index, slice): min_index = self.start_indices[index][0] max_index = self.end_indices[index][-1] - return FlattenLogprobs( + return FlatLogprobs( # Shift updated start_indices and end_indices to # be 0-indexed start_indices=[i - min_index for i in self.start_indices[index]], @@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice): raise TypeError(f"Invalid index type: {type(index)}") def __setitem__(self, item, value) -> None: - raise TypeError("Cannot set logprobs in FlattenLogprobs") + raise TypeError("Cannot set logprobs in FlatLogprobs") def __delitem__(self, item) -> None: - raise TypeError("Cannot delete logprobs from FlattenLogprobs") + raise TypeError("Cannot delete logprobs from FlatLogprobs") def insert(self, item) -> None: - raise TypeError("Cannot insert logprobs to FlattenLogprobs") + raise TypeError("Cannot insert logprobs to FlatLogprobs") def __iter__(self) -> Iterator[LogprobsOnePosition]: """ @@ -156,14 +156,14 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]: # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. -PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None] +PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None] # {token_id -> logprob} for each sequence group. -SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition] +SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition] def create_prompt_logprobs() -> PromptLogprobs: """Creates a container to store prompt logprobs for a request""" - logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else [] # NOTE: logprob of first prompt token is None. 
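    # Illustrative sketch of the flat layout described in the FlatLogprobs
    # docstring above (values are hypothetical and not part of this patch).
    # With two positions holding 1 and 2 ranked entries respectively:
    #   start_indices = [0, 1]
    #   end_indices   = [1, 3]
    #   token_ids     = [1, 2, 3]          # position 0 -> [1], position 1 -> [2, 3]
    #   logprobs      = [-0.1, -0.5, -1.2]
    #   ranks         = [1, 1, 2]
    # The entries for position i are the slice [start_indices[i]:end_indices[i]]
    # of each flat list, so the number of Python objects stays constant
    # regardless of sequence length or the top-k logprobs setting.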
logprobs.append(None) return logprobs @@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs: def create_sample_logprobs() -> SampleLogprobs: """Creates a container to store decode logprobs for a request""" - return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else [] def append_logprobs_for_next_position( @@ -191,7 +191,7 @@ def append_logprobs_for_next_position( topk_ranks = range(1, num_logprobs + 1) ranks = itertools.chain((rank,), topk_ranks) - if isinstance(request_logprobs, FlattenLogprobs): + if isinstance(request_logprobs, FlatLogprobs): request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens) else: request_logprobs.append( From bca74e32b7ef03515cda508ba88151e2e547bdc9 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 20:57:01 -0800 Subject: [PATCH 041/183] [Frontend] Add sagemaker_standards dynamic lora adapter and stateful session management decorators to vLLM OpenAI API server (#27892) Signed-off-by: Zuyi Zhao Signed-off-by: Shen Teng Co-authored-by: Shen Teng Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- requirements/common.txt | 1 + tests/entrypoints/sagemaker/__init__.py | 0 tests/entrypoints/sagemaker/conftest.py | 58 ++ .../test_sagemaker_handler_overrides.py | 734 ++++++++++++++++++ .../sagemaker/test_sagemaker_lora_adapters.py | 171 ++++ .../test_sagemaker_middleware_integration.py | 346 +++++++++ .../test_sagemaker_stateful_sessions.py | 153 ++++ vllm/entrypoints/dynamic_lora.py | 57 ++ vllm/entrypoints/openai/api_server.py | 100 +-- vllm/entrypoints/sagemaker/__init__.py | 4 + vllm/entrypoints/sagemaker/routes.py | 72 ++ 11 files changed, 1613 insertions(+), 83 deletions(-) create mode 100644 tests/entrypoints/sagemaker/__init__.py create mode 100644 tests/entrypoints/sagemaker/conftest.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py create mode 100644 vllm/entrypoints/dynamic_lora.py create mode 100644 vllm/entrypoints/sagemaker/__init__.py create mode 100644 vllm/entrypoints/sagemaker/routes.py diff --git a/requirements/common.txt b/requirements/common.txt index 8009581f62a4..90efb79a845d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -49,3 +49,4 @@ cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 +model-hosting-container-standards < 1.0.0 \ No newline at end of file diff --git a/tests/entrypoints/sagemaker/__init__.py b/tests/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/sagemaker/conftest.py b/tests/entrypoints/sagemaker/conftest.py new file mode 100644 index 000000000000..4c859c2527d2 --- /dev/null +++ b/tests/entrypoints/sagemaker/conftest.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared fixtures and utilities for SageMaker tests.""" + +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# Model name constants used across tests +MODEL_NAME_ZEPHYR = 
"HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct" +LORA_ADAPTER_NAME_SMOLLM = "jekunz/smollm-135m-lora-fineweb-faroese" + +# SageMaker header constants +HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id" +HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id" +HEADER_SAGEMAKER_NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id" + + +@pytest.fixture(scope="session") +def smollm2_lora_files(): + """Download LoRA files once per test session.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id=LORA_ADAPTER_NAME_SMOLLM) + + +@pytest.fixture(scope="module") +def basic_server_with_lora(smollm2_lora_files): + """Basic server fixture with standard configuration.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--max-lora-rank", + "256", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "64", + ] + + envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=envs) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def async_client(basic_server_with_lora: RemoteOpenAIServer): + """Async OpenAI client fixture for use with basic_server.""" + async with basic_server_with_lora.get_async_client() as async_client: + yield async_client diff --git a/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py new file mode 100644 index 000000000000..0d4f8e885824 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration tests for handler override functionality. + +Tests real customer usage scenarios: +- Using @custom_ping_handler and @custom_invocation_handler decorators + to override handlers +- Setting environment variables for handler specifications +- Writing customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions +- Priority: env vars > decorators > customer script files > framework + defaults + +Note: These tests focus on validating server responses rather than directly calling +get_ping_handler() and get_invoke_handler() to ensure full integration testing. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestHandlerOverrideIntegration: + """Integration tests simulating real customer usage scenarios. + + Each test simulates a fresh server startup where customers: + - Use @custom_ping_handler and @custom_invocation_handler decorators + - Set environment variables (CUSTOM_FASTAPI_PING_HANDLER, etc.) 
+ - Write customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions + """ + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + self._clear_env_vars() + + def teardown_method(self): + """Cleanup after each test.""" + self._clear_env_vars() + + def _clear_caches(self): + """Clear handler registry and function loader cache.""" + try: + from model_hosting_container_standards.common.handler import ( + handler_registry, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + handler_registry.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + def _clear_env_vars(self): + """Clear SageMaker environment variables.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + # Clear SageMaker env vars + for var in [ + SageMakerEnvVars.SAGEMAKER_MODEL_PATH, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME, + ]: + os.environ.pop(var, None) + + # Clear FastAPI env vars + for var in [ + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER, + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER, + ]: + os.environ.pop(var, None) + except ImportError: + pass + + @pytest.mark.asyncio + async def test_customer_script_functions_auto_loaded(self): + """Test customer scenario: script functions automatically override + framework defaults.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with ping() and invoke() functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "customer_override", + "message": "Custom ping from customer script" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Custom response from customer script"], + "source": "customer_override" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Customer sets SageMaker environment variables to point to their script + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Customer tests their server and sees their overrides work + # automatically + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their functions are used + assert ping_data["source"] == "customer_override" + assert ping_data["message"] == "Custom ping from customer script" + 
assert invoke_data["source"] == "customer_override" + assert invoke_data["predictions"] == [ + "Custom response from customer script" + ] + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_customer_decorator_usage(self): + """Test customer scenario: using @custom_ping_handler and + @custom_invocation_handler decorators.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +@sagemaker_standards.custom_ping_handler +async def my_ping(): + return { + "type": "ping", + "source": "customer_decorator" + } + +@sagemaker_standards.custom_invocation_handler +async def my_invoke(request: Request): + return { + "type": "invoke", + "source": "customer_decorator" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their handlers are used by the server + assert ping_data["source"] == "customer_decorator" + assert invoke_data["source"] == "customer_decorator" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_handler_priority_order(self): + """Test priority: @custom_ping_handler/@custom_invocation_handler + decorators vs script functions.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script with both decorator and regular functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +# Customer uses @custom_ping_handler decorator (higher priority than script functions) +@sagemaker_standards.custom_ping_handler +async def decorated_ping(): + return { + "status": "healthy", + "source": "ping_decorator_in_script", + "priority": "decorator" + } + +# Customer also has a regular function (lower priority than +# @custom_ping_handler decorator) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_function", + "priority": "function" + } + +# Customer has a regular invoke function +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke_function", + "priority": "function" + } +""" + ) + 
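            # A minimal customer script combining the override mechanisms
            # exercised in this module might look like the sketch below
            # (illustrative only; the names are hypothetical, not part of the patch):
            #
            #   # highest priority: env var target "<script filename>:<function name>"
            #   #   CUSTOM_FASTAPI_PING_HANDLER="my_script.py:my_env_ping"
            #   @sagemaker_standards.custom_ping_handler        # medium priority
            #   async def my_decorated_ping(): ...
            #   async def custom_sagemaker_ping_handler(): ...  # lowest priority
            #
            # The assertions further down verify that the decorator outranks the
            # plain script function when no handler env var is set.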
script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # @custom_ping_handler decorator has higher priority than + # script function + assert ping_data["source"] == "ping_decorator_in_script" + assert ping_data["priority"] == "decorator" + + # Script function is used for invoke + assert invoke_data["source"] == "script_invoke_function" + assert invoke_data["priority"] == "function" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_environment_variable_script_loading(self): + """Test that environment variables correctly specify script location + and loading.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script in a specific directory + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "env_loaded_script", + "method": "environment_variable_loading" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Loaded via environment variables"], + "source": "env_loaded_script", + "method": "environment_variable_loading" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Test environment variable script loading + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Verify that the script was loaded via environment variables + assert ping_data["source"] == "env_loaded_script" + assert ping_data["method"] == "environment_variable_loading" + assert invoke_data["source"] == "env_loaded_script" + assert invoke_data["method"] == "environment_variable_loading" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_framework_default_handlers(self): + """Test that framework default handlers work when no 
customer + overrides exist.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + # Explicitly pass empty env_dict to ensure no SageMaker env vars are set + # This prevents pollution from previous tests + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + env_dict = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: "", + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: "", + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: "", + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: "", + } + except ImportError: + env_dict = {} + + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=env_dict) as server: + # Test that default ping works + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + + # Test that default invocations work + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + + @pytest.mark.asyncio + async def test_handler_env_var_override(self): + """Test CUSTOM_FASTAPI_PING_HANDLER and CUSTOM_FASTAPI_INVOCATION_HANDLER + environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with both env var handlers and script functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request, Response +import json + +async def env_var_ping_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var_ping", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def env_var_invoke_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var_invoke", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_ping", + "method": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke", + "method": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to override both handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_var_ping_handler" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + f"{script_name}:env_var_invoke_handler" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler override + ping_response = requests.get(server.url_for("ping")) + 
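                # Note (illustrative, based on the env_vars dict above): the
                # handler env vars point at a "<script filename>:<function name>"
                # target, e.g. a hypothetical
                #   CUSTOM_FASTAPI_PING_HANDLER="my_handlers.py:my_ping"
                # so this ping request should be served by env_var_ping_handler
                # from the script rather than custom_sagemaker_ping_handler.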
assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable should override script function + assert ping_data["method"] == "environment_variable" + assert ping_data["source"] == "env_var_ping" + + # Test invocation handler override + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable should override script function + assert invoke_data["method"] == "environment_variable" + assert invoke_data["source"] == "env_var_invoke" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_env_var_priority_over_decorator_and_script(self): + """Test that environment variables have highest priority over decorators + and script functions for both ping and invocation handlers.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with all three handler types for both ping and invocation + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request, Response +import json + +# Environment variable handlers (highest priority) +async def env_priority_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +async def env_priority_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +# Decorator handlers (medium priority) +@sagemaker_standards.custom_ping_handler +async def decorator_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +@sagemaker_standards.custom_invocation_handler +async def decorator_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Decorator response"], + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +# Script functions (lowest priority) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script", + "priority": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script", + "priority": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to specify highest priority handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_priority_ping" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + 
f"{script_name}:env_priority_invoke" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler priority + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable has highest priority and should be used + assert ping_data["priority"] == "environment_variable" + assert ping_data["source"] == "env_var" + + # Test invocation handler priority + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable has highest priority and should be used + assert invoke_data["priority"] == "environment_variable" + assert invoke_data["source"] == "env_var" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py new file mode 100644 index 000000000000..a2867efdc584 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import openai # use the official async_client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import MODEL_NAME_SMOLLM + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # The SageMaker standards library creates a POST /adapters endpoint + # that maps to the load_lora_adapter handler with request shape: + # {"lora_name": "body.name", "lora_path": "body.src"} + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "smollm2-lora-sagemaker", "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + models = await async_client.models.list() + models = models.data + dynamic_lora_model = models[-1] + assert dynamic_lora_model.root == smollm2_lora_files + assert dynamic_lora_model.parent == MODEL_NAME_SMOLLM + assert dynamic_lora_model.id == "smollm2-lora-sagemaker" + + +@pytest.mark.asyncio +async def test_sagemaker_unload_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter + adapter_name = "smollm2-lora-sagemaker-unload" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify it's in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + assert adapter_name in adapter_ids + + # Now unload it using DELETE /adapters/{adapter_name} + # The SageMaker standards maps this to unload_lora_adapter with: + # {"lora_name": "path_params.adapter_name"} + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify it's no longer in the models list + models = await async_client.models.list() + adapter_ids = 
[model.id for model in models.data] + assert adapter_name not in adapter_ids + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_not_found( + basic_server_with_lora: RemoteOpenAIServer, +): + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "nonexistent-adapter", "src": "/path/does/not/exist"}, + ) + assert load_response.status_code == 404 + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_invalid_files( + basic_server_with_lora: RemoteOpenAIServer, + tmp_path, +): + invalid_files = tmp_path / "invalid_adapter" + invalid_files.mkdir() + (invalid_files / "adapter_config.json").write_text("not valid json") + + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "invalid-adapter", "src": str(invalid_files)}, + ) + assert load_response.status_code == 400 + + +@pytest.mark.asyncio +async def test_sagemaker_unload_nonexistent_adapter( + basic_server_with_lora: RemoteOpenAIServer, +): + # Attempt to unload an adapter that doesn't exist + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", "nonexistent-adapter-name"), + ) + assert unload_response.status_code in (400, 404) + + +@pytest.mark.asyncio +async def test_sagemaker_invocations_with_adapter( + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter via SageMaker endpoint + adapter_name = "smollm2-lora-invoke-test" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Now test the /invocations endpoint with the adapter + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={ + "X-Amzn-SageMaker-Adapter-Identifier": adapter_name, + }, + json={ + "prompt": "Hello, how are you?", + "max_tokens": 10, + }, + ) + invocation_response.raise_for_status() + invocation_output = invocation_response.json() + + # Verify we got a valid completion response + assert "choices" in invocation_output + assert len(invocation_output["choices"]) > 0 + assert "text" in invocation_output["choices"][0] + + +@pytest.mark.asyncio +async def test_sagemaker_multiple_adapters_load_unload( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + adapter_names = [f"sagemaker-adapter-{i}" for i in range(5)] + + # Load all adapters + for adapter_name in adapter_names: + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify all are in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name in adapter_ids + + # Unload all adapters + for adapter_name in adapter_names: + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify all are removed from models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name not in adapter_ids diff --git a/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py new file mode 100644 index 000000000000..f1ed0c7e2897 --- /dev/null +++ 
b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration test for middleware loader functionality. + +Tests that customer middlewares get called correctly with a vLLM server. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestMiddlewareIntegration: + """Integration test for middleware with vLLM server.""" + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + + def _clear_caches(self): + """Clear middleware registry and function loader cache.""" + try: + from model_hosting_container_standards.common.fastapi.middleware import ( + middleware_registry, + ) + from model_hosting_container_standards.common.fastapi.middleware.source.decorator_loader import ( # noqa: E501 + decorator_loader, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + middleware_registry.clear_middlewares() + decorator_loader.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + @pytest.mark.asyncio + async def test_customer_middleware_with_vllm_server(self): + """Test that customer middlewares work with actual vLLM server. + + Tests decorator-based middlewares (@custom_middleware, @input_formatter, + @output_formatter) + on multiple endpoints (chat/completions, invocations). + """ + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script with multiple decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware, input_formatter, output_formatter +) + +# Global flag to track if input formatter was called +_input_formatter_called = False + +@input_formatter +async def customer_input_formatter(request): + # Process input - mark that input formatter was called + global _input_formatter_called + _input_formatter_called = True + return request + +@custom_middleware("throttle") +async def customer_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Customer-Throttle"] = "applied" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "throttle," + return response + +@output_formatter +async def customer_output_formatter(response): + global _input_formatter_called + response.headers["X-Customer-Processed"] = "true" + # Since input_formatter and output_formatter are combined into + # pre_post_process middleware, + # if output_formatter is called, input_formatter should have been called too + if _input_formatter_called: + response.headers["X-Input-Formatter-Called"] = "true" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "output_formatter," + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to point to customer script + env_vars = { + 
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test 1: Middlewares applied to chat/completions endpoint + chat_response = requests.post( + server.url_for("v1/chat/completions"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert chat_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in chat_response.headers + assert chat_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in chat_response.headers + assert chat_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in chat_response.headers + assert chat_response.headers["X-Input-Formatter-Called"] == "true" + + # Verify middleware execution order + execution_order = chat_response.headers.get( + "X-Middleware-Order", "" + ).rstrip(",") + order_parts = execution_order.split(",") if execution_order else [] + assert "throttle" in order_parts + assert "output_formatter" in order_parts + + # Test 2: Middlewares applied to invocations endpoint + invocations_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert invocations_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in invocations_response.headers + assert invocations_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in invocations_response.headers + assert invocations_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in invocations_response.headers + assert ( + invocations_response.headers["X-Input-Formatter-Called"] == "true" + ) + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_with_ping_endpoint(self): + """Test that middlewares work with SageMaker ping endpoint.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware +) + +@custom_middleware("pre_post_process") +async def ping_tracking_middleware(request, call_next): + response = await call_next(request) + if request.url.path == "/ping": + response.headers["X-Ping-Tracked"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping endpoint with 
middleware + response = requests.get(server.url_for("ping")) + + assert response.status_code == 200 + assert "X-Ping-Tracked" in response.headers + assert response.headers["X-Ping-Tracked"] == "true" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_env_var_override(self): + """Test middleware environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with middleware functions specified via env vars + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +# Global flag to track if pre_process was called +_pre_process_called = False + +async def env_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Env-Throttle"] = "applied" + return response + +async def env_pre_process(request: Request) -> Request: + # Mark that pre_process was called + global _pre_process_called + _pre_process_called = True + return request + +async def env_post_process(response): + global _pre_process_called + if hasattr(response, 'headers'): + response.headers["X-Env-Post-Process"] = "applied" + # Since pre_process and post_process are combined into + # pre_post_process middleware, + # if post_process is called, pre_process should have been called too + if _pre_process_called: + response.headers["X-Pre-Process-Called"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables for middleware + # Use script_name with .py extension as per plugin example + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_MIDDLEWARE_THROTTLE: ( + f"{script_name}:env_throttle_middleware" + ), + FastAPIEnvVars.CUSTOM_PRE_PROCESS: f"{script_name}:env_pre_process", + FastAPIEnvVars.CUSTOM_POST_PROCESS: f"{script_name}:env_post_process", + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + response = requests.get(server.url_for("ping")) + assert response.status_code == 200 + + # Check if environment variable middleware was applied + headers = response.headers + + # Verify that env var middlewares were applied + assert "X-Env-Throttle" in headers, ( + "Throttle middleware should be applied via env var" + ) + assert headers["X-Env-Throttle"] == "applied" + + assert "X-Env-Post-Process" in headers, ( + "Post-process middleware should be applied via env var" + ) + assert headers["X-Env-Post-Process"] == "applied" + + # Verify that pre_process was called + assert "X-Pre-Process-Called" in headers, ( + "Pre-process should be called via env var" + ) + assert headers["X-Pre-Process-Called"] == "true" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py new file mode 100644 index 000000000000..6206000385bd --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py @@ -0,0 
+1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import openai # use the official client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + HEADER_SAGEMAKER_CLOSED_SESSION_ID, + HEADER_SAGEMAKER_NEW_SESSION_ID, + HEADER_SAGEMAKER_SESSION_ID, + MODEL_NAME_SMOLLM, +) + +CLOSE_BADREQUEST_CASES = [ + ( + "nonexistent_session_id", + {"session_id": "nonexistent-session-id"}, + {}, + "session not found", + ), + ("malformed_close_request", {}, {"extra-field": "extra-field-data"}, None), +] + + +@pytest.mark.asyncio +async def test_create_session_badrequest(basic_server_with_lora: RemoteOpenAIServer): + bad_response = requests.post( + basic_server_with_lora.url_for("invocations"), + json={"requestType": "NEW_SESSION", "extra-field": "extra-field-data"}, + ) + + assert bad_response.status_code == 400 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_name,session_id_change,request_body_change,expected_error", + CLOSE_BADREQUEST_CASES, +) +async def test_close_session_badrequest( + basic_server_with_lora: RemoteOpenAIServer, + test_name: str, + session_id_change: dict[str, str], + request_body_change: dict[str, str], + expected_error: str | None, +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + if request_body_change: + close_request_json.update(request_body_change) + bad_session_id = session_id_change.get("session_id") + bad_close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: bad_session_id or valid_session_id}, + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert bad_close_response.status_code == 400 + if expected_error: + assert expected_error in bad_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_close_session_invalidrequest( + basic_server_with_lora: RemoteOpenAIServer, async_client: openai.AsyncOpenAI +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + invalid_close_response = requests.post( + url, + # no headers to specify session_id + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert invalid_close_response.status_code == 424 + assert "invalid session_id" in invalid_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_session(basic_server_with_lora: RemoteOpenAIServer): + # first attempt to create a session + url = 
basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + # test invocation with session id + + request_args = { + "model": MODEL_NAME_SMOLLM, + "prompt": "what is 1+1?", + "max_completion_tokens": 5, + "temperature": 0.0, + "logprobs": False, + } + + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json=request_args, + ) + invocation_response.raise_for_status() + + # close created session, should succeed + close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + close_response.raise_for_status() + + assert ( + close_response.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID) + == valid_session_id + ) diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/dynamic_lora.py new file mode 100644 index 000000000000..cc0f437e5c77 --- /dev/null +++ b/vllm/entrypoints/dynamic_lora.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import models, validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def register_dynamic_lora_routes(router: APIRouter): + @sagemaker_standards.register_load_adapter_handler( + request_shape={ + "lora_name": "body.name", + "lora_path": "body.src", + }, + ) + @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): + handler: OpenAIServingModels = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + @sagemaker_standards.register_unload_adapter_handler( + request_shape={ + "lora_name": "path_params.adapter_name", + } + ) + @router.post( + "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] + ) + async def unload_lora_adapter( + request: UnloadLoRAAdapterRequest, raw_request: Request + ): + handler: OpenAIServingModels = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + return router diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 51191879e478..fbb2d32a229d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -19,6 +19,7 @@ from http import HTTPStatus from typing import Annotated, Any, Literal +import model_hosting_container_standards.sagemaker as sagemaker_standards import 
prometheus_client import pydantic import regex as re @@ -65,7 +66,6 @@ ErrorInfo, ErrorResponse, IOProcessorResponse, - LoadLoRAAdapterRequest, PoolingBytesResponse, PoolingRequest, PoolingResponse, @@ -82,7 +82,6 @@ TranscriptionResponse, TranslationRequest, TranslationResponse, - UnloadLoRAAdapterRequest, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_classification import ServingClassification @@ -387,13 +386,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.get("/ping", response_class=Response) -@router.post("/ping", response_class=Response) -async def ping(raw_request: Request) -> Response: - """Ping check. Endpoint required for SageMaker""" - return await health(raw_request) - - @router.post( "/tokenize", dependencies=[Depends(validate_json_request)], @@ -1236,47 +1228,6 @@ async def is_scaling_elastic_ep(raw_request: Request): ] -@router.post( - "/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def invocations(raw_request: Request): - """For SageMaker, routes requests based on the request type.""" - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}" - ) from e - - valid_endpoints = [ - (validator, endpoint) - for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None - ] - - for request_validator, endpoint in valid_endpoints: - try: - request = request_validator.validate_python(body) - except pydantic.ValidationError: - continue - - return await endpoint(request, raw_request) - - type_names = [ - t.__name__ if isinstance(t := validator._type, type) else str(t) - for validator, _ in valid_endpoints - ] - msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" - res = base(raw_request).create_error_response(message=msg) - return JSONResponse(content=res.model_dump(), status_code=res.error.code) - - if envs.VLLM_TORCH_PROFILER_DIR: logger.warning_once( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -1304,39 +1255,6 @@ async def stop_profile(raw_request: Request): return Response(status_code=200) -if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!" 
- ) - - @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) - async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): - handler = models(raw_request) - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - @router.post( - "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] - ) - async def unload_lora_adapter( - request: UnloadLoRAAdapterRequest, raw_request: Request - ): - handler = models(raw_request) - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1606,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" + ) + from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes + + register_dynamic_lora_routes(router) + + from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes + + register_sagemaker_routes(router) + app.include_router(router) app.root_path = args.root_path @@ -1696,6 +1628,8 @@ async def log_response(request: Request, call_next): f"Invalid middleware {middleware}. Must be a function or a class." ) + app = sagemaker_standards.bootstrap(app) + return app diff --git a/vllm/entrypoints/sagemaker/__init__.py b/vllm/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..c1767137e4ea --- /dev/null +++ b/vllm/entrypoints/sagemaker/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""SageMaker-specific integration for vLLM.""" diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py new file mode 100644 index 000000000000..498b7294f0d8 --- /dev/null +++ b/vllm/entrypoints/sagemaker/routes.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from http import HTTPStatus + +import model_hosting_container_standards.sagemaker as sagemaker_standards +import pydantic +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import ( + INVOCATION_VALIDATORS, + base, + health, + validate_json_request, +) +from vllm.entrypoints.openai.protocol import ErrorResponse + + +def register_sagemaker_routes(router: APIRouter): + @router.post("/ping", response_class=Response) + @router.get("/ping", response_class=Response) + @sagemaker_standards.register_ping_handler + async def ping(raw_request: Request) -> Response: + """Ping check. 
Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, + ) + @sagemaker_standards.register_invocation_handler + @sagemaker_standards.stateful_session_manager() + @sagemaker_standards.inject_adapter_id(adapter_path="model") + async def invocations(raw_request: Request): + """For SageMaker, routes requests based on the request type.""" + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] + + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue + + return await endpoint(request, raw_request) + + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.error.code) + + return router From e605e8e3233f895340f46665f93ab37b307491aa Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Tue, 11 Nov 2025 00:59:08 -0500 Subject: [PATCH 042/183] [Bugfix] Fix Stream Sync for Shared Expert Overlap (#28430) Signed-off-by: Vadim Gimpelson Signed-off-by: Robert Shaw Co-authored-by: Vadim Gimpelson --- .../gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml | 3 -- vllm/model_executor/layers/fused_moe/layer.py | 45 +++++++------------ 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml index ea9c95158405..9297bf6ddf2d 100644 --- a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -3,6 +3,3 @@ accuracy_threshold: 0.45 num_questions: 1319 num_fewshot: 5 max_model_len: 4096 -# Duo stream incompatabilbe with this model: https://github.com/vllm-project/vllm/issues/28220 -env: - VLLM_DISABLE_SHARED_EXPERTS_STREAM: "1" diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 27ad9c8fd1c2..39547cc83c7b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2456,28 +2456,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) - # If there are shared experts but we are not using a modular kernel, - # the shared experts must be called here - if has_separate_shared_experts: - assert self.shared_experts is not None - - if self.shared_experts_stream is not None: - # For chunked, we start the shared experts stream here - # (Note that no concurrency with the router/gate) - self.shared_experts_stream.wait_stream(current_stream()) - - with torch.cuda.stream(self.shared_experts_stream): - # 
Note that staged_hidden_states clone() is necessary - # here to avoid conflict with the main stream - shared_output = self.shared_experts( - staged_hidden_states.clone() - ) - else: - shared_output = self.shared_experts(staged_hidden_states) - - else: - shared_output = None - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -2506,11 +2484,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): if has_separate_shared_experts: assert not isinstance(final_hidden_states, tuple) assert self.shared_experts is not None - - # Here we finish the shared experts stream - if self.shared_experts_stream is not None: - current_stream().wait_stream(self.shared_experts_stream) - + shared_output = self.shared_experts(staged_hidden_states) final_hidden_states = ( shared_output, final_hidden_states, @@ -2619,11 +2593,22 @@ def forward_impl( assert self.shared_experts is not None if self.shared_experts_stream is not None: + # Clone BEFORE switching streams to avoid race condition + # where routed_expert kernel may mutate hidden_states. + hidden_states_clone = hidden_states.clone() + self.shared_experts_stream.wait_stream(current_stream()) + # Run shared experts in parallel on a separate stream with torch.cuda.stream(self.shared_experts_stream): - # Note that hidden_states clone() is necessary here to avoid - # conflict with the main stream - shared_output = self.shared_experts(hidden_states.clone()) + shared_output = self.shared_experts(hidden_states_clone) + + # Record that the clone will be used by shared_experts_stream + # to avoid gc issue from deallocation of hidden_states_clone + # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 + # NOTE: we dont need shared_output.record_stream(current_stream()) + # because we synch the streams before using shared_output. + hidden_states_clone.record_stream(self.shared_experts_stream) + else: shared_output = self.shared_experts(hidden_states) else: From a7adbc6c6b4bcdef5cfffdcd06edf86fcbfb7c69 Mon Sep 17 00:00:00 2001 From: iAmir97 <71513472+iAmir97@users.noreply.github.com> Date: Tue, 11 Nov 2025 13:44:35 +0700 Subject: [PATCH 043/183] [Doc] Sleep mode documentation (#28357) Signed-off-by: Amir Balwel Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com> Co-authored-by: Amir Balwel Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/features/sleep_mode.md | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md index e7dd9fee12d3..edcbaa716447 100644 --- a/docs/features/sleep_mode.md +++ b/docs/features/sleep_mode.md @@ -13,6 +13,9 @@ Key benefits: !!! note This feature is only supported on CUDA platform. +!!! note + For more information, see this [Blog Post](https://blog.vllm.ai/2025/10/26/sleep-mode.html). + ## Sleep levels Level 1 sleep will offload the model weights and discard the KV cache. The content of KV cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed up in CPU memory. Please make sure there's enough CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the KV cache (while the model's buffers are kept in CPU, like rope scaling tensors). The content of both the model weights and KV cache is forgotten. 
Level 2 sleep is good for sleeping and waking up the engine to run a different model or update the model, where previous model weights are not needed, e.g. RLHF weight update. @@ -31,6 +34,7 @@ llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True) #### Python API ```python +# Sleep level 1 # Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache) llm.sleep(level=1) @@ -38,6 +42,21 @@ llm.sleep(level=1) llm.wake_up() ``` +```python +# Sleep level 2 +# Put the engine to sleep (level=2: discard both weights and KV cache) +llm.sleep(level=2) + +# Reallocate weights memory only +llm.wake_up(tags=["weights"]) + +# Load weights in-place +llm.collective_rpc("reload_weights") + +# Reallocate KV cache +llm.wake_up(tags=["kv_cache"]) +``` + #### RLHF weight updates During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the tags argument in wake_up(). This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., llm.wake_up(tags=["weights"])), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations. @@ -69,10 +88,30 @@ VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ --port 8000 ``` +Below is an example of how to sleep and wake up a model in level 1. + +```bash +curl -X POST 'http://localhost:8000/sleep?level=1' +curl -X POST 'http://localhost:8000/wake_up' +``` + +And this is an example of how to sleep and wake up a model in level 2. + +```bash +curl -X POST 'http://localhost:8000/sleep?level=2' +# Reallocate weights memory only +curl -X POST 'http://localhost:8000/wake_up?tags=weights' +# Load weights in-place +curl -X POST 'http://localhost:8000/collective_rpc' -H 'Content-Type: application/json' -d '{"method":"reload_weights"}' +# Reallocate KV cache +curl -X POST 'http://localhost:8000/wake_up?tags=kv_cache' +``` + #### HTTP endpoints - `POST /sleep?level=1` — Put the model to sleep (`level=1`). - `POST /wake_up` — Wake up the model. Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`). +- `POST /collective_rpc` — Perform a collective remote procedure call (RPC). - `GET /is_sleeping` — Check if the model is sleeping. !!! 
note From cc079763c59adb8c03305663a5b8857ba85deb1b Mon Sep 17 00:00:00 2001 From: David Ben-David Date: Tue, 11 Nov 2025 09:39:36 +0200 Subject: [PATCH 044/183] [BugFix] Avoid calling KV connector layer APIs when metadata is unset (#28253) Signed-off-by: David Ben-David Co-authored-by: David Ben-David Co-authored-by: Mark McLoughlin --- vllm/attention/layer.py | 4 ++++ vllm/distributed/kv_transfer/kv_connector/v1/base.py | 9 ++++++++- .../kv_transfer/kv_connector/v1/multi_connector.py | 6 ++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 96272981692c..acab0529f352 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -837,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str): return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -854,6 +856,8 @@ def maybe_save_kv_layer_to_connector( return connector = get_kv_transfer_group() + if not connector.has_connector_metadata(): + return forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 354aa9a87183..f85eb414b222 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -204,11 +204,18 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: Returns: ConnectorMetadata: the connector metadata. """ - # Should only be called while set to valid metadata. assert self._connector_metadata is not None return self._connector_metadata + def has_connector_metadata(self) -> bool: + """Check whether the connector metadata is currently set. + + Returns: + bool: True if connector metadata exists, False otherwise. + """ + return self._connector_metadata is not None + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ Initialize with the KV caches. Useful for pre-registering the diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d7bbf02c8367..c9d08e9b78ed 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -171,16 +171,22 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. + # + # Note: Call the base class method to ensure metadata is also set on the + # MultiConnector instance itself; otherwise, `has_connector_metadata()` will + # always return False. 
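+    # (The base implementation records the metadata object on the instance,
+    #  which is exactly what `has_connector_metadata()` checks.)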
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) if connector_metadata.extra_async_saves: self._extra_async_saves.update(connector_metadata.extra_async_saves) for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) + super().bind_connector_metadata(connector_metadata) def clear_connector_metadata(self) -> None: for c in self._connectors: c.clear_connector_metadata() + super().clear_connector_metadata() def shutdown(self): exception: Exception | None = None From 4fd4b743a23cc6ccbd832f11be12317a8c2f0fbc Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 11 Nov 2025 00:07:24 -0800 Subject: [PATCH 045/183] [Bugfix] Fix max image size for PaddleOCR-VL (#28442) Signed-off-by: Roger Wang --- vllm/model_executor/models/paddleocr_vl.py | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 631475c964c0..12ae15699e7d 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -198,23 +198,18 @@ def get_num_image_tokens( if image_processor is None: image_processor = self.get_image_processor() - do_resize = True hf_config = self.get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size - - if do_resize: - resized_height, resized_width = smart_resize( - height=image_height, - width=image_width, - factor=patch_size * merge_size, - min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, - ) - preprocessed_size = ImageSize(width=resized_width, height=resized_height) - else: - preprocessed_size = ImageSize(width=image_width, height=image_height) + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) grid_t = 1 grid_h = preprocessed_size.height // patch_size @@ -227,8 +222,19 @@ def get_num_image_tokens( def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() - image_size = hf_config.vision_config.image_size - return ImageSize(height=image_size, width=image_size) + + # See `smart_resize` for the calculation of the image size. + merge_size = hf_config.vision_config.spatial_merge_size + patch_size = hf_config.vision_config.patch_size + factor = merge_size * patch_size + max_num_tokens = self.get_image_processor().max_pixels // (factor**2) + # Find factors of max_num_tokens close to its square root + # to create a dummy image with a reasonable aspect ratio. 
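+        # (Illustrative example, not the real processor limit: if
+        #  max_num_tokens were 1000, the loop below would settle on
+        #  h_patches = 25 and w_patches = 40.)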
+ h_patches = int(math.sqrt(max_num_tokens)) + while max_num_tokens % h_patches != 0: + h_patches -= 1 + w_patches = max_num_tokens // h_patches + return ImageSize(height=h_patches * factor, width=w_patches * factor) class PaddleOCRVLDummyInputsBuilder(BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]): From 798c7bebca5e3ea48b947af4cc7904a4507ba873 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 11 Nov 2025 00:19:51 -0800 Subject: [PATCH 046/183] [EPLB] Refactor balance_packing to use numpy and optimize GPU-CPU transfers in EPLB (#28369) Signed-off-by: Sage Moore --- vllm/distributed/eplb/rebalance_algo.py | 40 +++++++++++++++------- vllm/distributed/eplb/rebalance_execute.py | 14 +++++--- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index c9d30d6481ab..e6645e524cc3 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -12,6 +12,7 @@ on how the EPLB algorithm works. """ +import numpy as np import torch @@ -34,29 +35,44 @@ def balanced_packing( assert num_groups % num_packs == 0 groups_per_pack = num_groups // num_packs + device = weight.device + if groups_per_pack == 1: pack_index = torch.arange( - weight.size(-1), dtype=torch.int64, device=weight.device + weight.size(-1), dtype=torch.int64, device=device ).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device) return pack_index, rank_in_pack - indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") - rank_in_pack = torch.full_like(pack_index, fill_value=-1) + weight_np = weight.cpu().numpy() + + # Sort and get indices in decending order + indices_np = np.argsort(-weight_np, axis=-1) + + pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + + # Run the packing algorithm for i in range(num_layers): - pack_weights = [0] * num_packs + pack_weights = [0.0] * num_packs pack_items = [0] * num_packs - for group in indices[i]: + + for group in indices_np[i]: + # Find a pack with capacity that has the lowest weight pack = min( - (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + (j for j in range(num_packs) if pack_items[j] < groups_per_pack), key=pack_weights.__getitem__, ) + assert pack_items[pack] < groups_per_pack - pack_index[i, group] = pack - rank_in_pack[i, group] = pack_items[pack] - pack_weights[pack] += weight[i, group] + pack_index_np[i, group] = pack + rank_in_pack_np[i, group] = pack_items[pack] + pack_weights[pack] += weight_np[i, group] pack_items[pack] += 1 + + pack_index = torch.from_numpy(pack_index_np).to(device) + rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) + return pack_index, rank_in_pack @@ -212,7 +228,7 @@ def rebalance_experts( replicas for each logical expert """ num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() + weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8ec3e956401..5c1efbaf03ba 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -321,15 +321,19 @@ def 
rearrange_expert_weights_inplace( ) return + old_global_expert_indices_cpu = old_global_expert_indices.cpu() + new_global_expert_indices_cpu = new_global_expert_indices.cpu() + + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + for layer in range(num_moe_layers): - # NOTE(bowen): We need this synchronize to run, but I don't know why. - # If you figure out the reason, please let me know -- thank you! - torch.cuda.synchronize() shuffle_layer( num_local_physical_experts, ep_rank, - old_global_expert_indices[layer].tolist(), - new_global_expert_indices[layer].tolist(), + old_global_expert_indices_cpu[layer].tolist(), + new_global_expert_indices_cpu[layer].tolist(), expert_weights[layer], expert_weights_buffer, ep_group, From f0359fffa434a4fce981389f9dff93a2a4c2b13e Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Tue, 11 Nov 2025 16:24:28 +0800 Subject: [PATCH 047/183] [Bugfix] fix qwen3-next crash (#28202) Signed-off-by: zjy0516 --- vllm/model_executor/models/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index aa7de5aa5f29..ddb8693c16e2 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -587,7 +587,7 @@ def _forward_core( self.conv1d.bias, self.activation, conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_decodes + : attn_metadata.num_actual_tokens ], validate_data=True, ) From c7991269dd8fe86096a3eee5040e855801ae9665 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Nov 2025 16:45:38 +0800 Subject: [PATCH 048/183] [BugFix] 'DeepseekV2Config' object has no attribute 'use_mla'` (#28387) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/kimi_vl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index b54f53931d71..b79bdf8595ca 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -456,7 +456,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - if not config.use_mla: + use_mha = ( + config.model_type == "deepseek" + or config.qk_nope_head_dim + config.qk_rope_head_dim == 0 + ) + if use_mha: stacked_params_mapping += [ (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), From 9973e6e04ad3e4a6c74c51a2dc87b2d3ddc4837f Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 11 Nov 2025 10:35:10 +0000 Subject: [PATCH 049/183] [Model][Qwen3VL] Slighly speedup `fast_pos_embed_interpolate` (#28434) Signed-off-by: Lukas Geiger --- vllm/model_executor/models/qwen3_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fe0124ef3258..1cd34bf54a35 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -491,8 +491,8 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: weights = weights.to(dtype=self.dtype) embeds = self.pos_embed(indices) - weighted_embeds = embeds * weights - combined = weighted_embeds.sum(dim=0) + embeds *= weights + combined = embeds.sum(dim=0) combined = combined.reshape( h // m_size, m_size, w // m_size, m_size, hidden_dim From 
d381eb967f171ea8824357075b15bf2895619609 Mon Sep 17 00:00:00 2001 From: Ido Segev Date: Tue, 11 Nov 2025 13:06:04 +0200 Subject: [PATCH 050/183] Multi turn benchmark progress bar for synthetic conversation generation (#28394) Signed-off-by: Ido Segev --- benchmarks/multi_turn/bench_dataset.py | 18 +++++++++++++++--- benchmarks/multi_turn/requirements.txt | 3 ++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 2674899d1cc5..8cb8a2f386a9 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -11,6 +11,7 @@ Color, logger, ) +from tqdm import tqdm from transformers import AutoTokenizer # type: ignore # Conversation ID is a string (e.g: "UzTK34D") @@ -417,6 +418,10 @@ def generate_conversations( data = file.read() tokens_in_file = tokenizer.encode(data, add_special_tokens=False) list_of_tokens.extend(tokens_in_file) + logger.info( + f"Loaded {len(tokens_in_file)} tokens from file {filename}, " + f"total tokens so far: {len(list_of_tokens)}" + ) conversations: ConversationsMap = {} conv_id = 0 @@ -449,18 +454,25 @@ def generate_conversations( ) base_offset += common_prefix_tokens - for conv_id in range(args.num_conversations): + for conv_id in tqdm( + range(args.num_conversations), + total=args.num_conversations, + desc="Generating conversations", + unit="conv", + ): # Generate a single conversation messages: MessagesList = [] nturns = turn_count[conv_id] # User prompt token count per turn (with lower limit) - input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int) input_token_count = np.maximum(input_token_count, base_prompt_token_count) # Assistant answer token count per turn (with lower limit) - output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype( + int + ) output_token_count = np.maximum(output_token_count, 1) user_turn = True diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt index f0e1935914a1..bae656a5c5c4 100644 --- a/benchmarks/multi_turn/requirements.txt +++ b/benchmarks/multi_turn/requirements.txt @@ -2,4 +2,5 @@ numpy>=1.24 pandas>=2.0.0 aiohttp>=3.10 transformers>=4.46 -xlsxwriter>=3.2.1 \ No newline at end of file +xlsxwriter>=3.2.1 +tqdm>=4.66 From 2e78150d24e339bf6420a623cdae655051127d8f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 05:28:28 -0700 Subject: [PATCH 051/183] [CI] Add mergify rules for `nvidia` label (#28417) Signed-off-by: mgoin --- .github/mergify.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/mergify.yml b/.github/mergify.yml index 18d4a2e83144..997a40e18e58 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -151,6 +151,23 @@ pull_request_rules: add: - gpt-oss +- name: label-nvidia + description: Automatically apply nvidia label + conditions: + - label != stale + - or: + - files~=cuda + - files~=cutlass + - files~=flashinfer + - files~=trtllm + - title~=(?i)NVIDIA + - title~=(?i)CUDA + - title~=(?i)CUTLASS + actions: + label: + add: + - nvidia + - name: label-rocm description: Automatically apply rocm label conditions: From b30dfa03c564ce51c56bf2dd16283f074253c27c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 11 Nov 2025 06:40:44 -0600 Subject: [PATCH 052/183] [Attention] Refactor CUDA attention backend 
selection logic (#24794) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthew Bonanni Signed-off-by: Matthew Bonanni Co-authored-by: Luka Govedič --- .buildkite/test-pipeline.yaml | 5 + tests/compile/test_fusion_attn.py | 31 +- tests/compile/test_fusions_e2e.py | 24 +- tests/config/test_multimodal_config.py | 6 +- .../attention/test_attention_selector.py | 75 ++-- tests/kernels/attention/test_mha_attn.py | 12 +- tests/models/test_initialization.py | 11 + tests/v1/attention/test_attention_backends.py | 47 ++- tests/v1/attention/test_mla_backends.py | 29 +- tests/v1/attention/utils.py | 10 +- tests/v1/spec_decode/test_eagle.py | 18 +- tests/v1/spec_decode/test_mtp.py | 6 +- tests/v1/spec_decode/test_tree_attention.py | 8 +- tests/v1/worker/test_gpu_model_runner.py | 25 +- vllm/attention/backends/abstract.py | 149 ++++++- vllm/attention/backends/registry.py | 252 ++++++++---- vllm/attention/layer.py | 68 ++-- vllm/attention/selector.py | 124 +++--- vllm/config/cache.py | 10 +- vllm/config/model.py | 8 +- vllm/config/multimodal.py | 32 +- .../kv_connector/v1/nixl_connector.py | 8 +- vllm/engine/arg_utils.py | 4 +- vllm/envs.py | 6 +- vllm/model_executor/models/dots_ocr.py | 37 +- vllm/model_executor/models/ernie45_vl.py | 37 +- vllm/model_executor/models/glm4_1v.py | 35 +- vllm/model_executor/models/keye.py | 24 +- vllm/model_executor/models/ovis2_5.py | 6 +- vllm/model_executor/models/paddleocr_vl.py | 47 +-- vllm/model_executor/models/qwen2_5_vl.py | 42 +- vllm/model_executor/models/qwen2_vl.py | 38 +- .../models/qwen3_omni_moe_thinker.py | 15 +- vllm/model_executor/models/qwen3_vl.py | 26 +- vllm/model_executor/models/siglip2navit.py | 26 +- vllm/model_executor/models/vision.py | 8 +- vllm/platforms/cpu.py | 12 +- vllm/platforms/cuda.py | 362 +++++++++--------- vllm/platforms/interface.py | 42 +- vllm/platforms/rocm.py | 49 ++- vllm/platforms/tpu.py | 15 +- vllm/platforms/xpu.py | 34 +- vllm/v1/attention/backends/cpu_attn.py | 32 +- vllm/v1/attention/backends/flash_attn.py | 71 ++-- vllm/v1/attention/backends/flashinfer.py | 63 +-- vllm/v1/attention/backends/flex_attention.py | 21 +- vllm/v1/attention/backends/mla/common.py | 22 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 16 +- .../attention/backends/mla/flashattn_mla.py | 27 ++ .../attention/backends/mla/flashinfer_mla.py | 26 +- vllm/v1/attention/backends/mla/flashmla.py | 37 +- .../attention/backends/mla/flashmla_sparse.py | 30 +- vllm/v1/attention/backends/mla/indexer.py | 6 +- vllm/v1/attention/backends/mla/triton_mla.py | 10 + vllm/v1/attention/backends/rocm_aiter_fa.py | 25 +- vllm/v1/attention/backends/rocm_attn.py | 10 +- vllm/v1/attention/backends/tree_attn.py | 26 +- vllm/v1/attention/backends/triton_attn.py | 47 ++- vllm/v1/attention/backends/xformers.py | 26 +- vllm/v1/spec_decode/eagle.py | 8 +- vllm/v1/worker/gpu_model_runner.py | 4 +- 61 files changed, 1333 insertions(+), 997 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a0d2076199b1..83a7df3b093f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -890,11 +890,16 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py commands: - 
nvidia-smi - python3 examples/offline_inference/basic/chat.py # Attention # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index fecb1e2e918f..ea61c94953a7 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -10,7 +10,7 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes @@ -104,7 +104,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend - if backend == _Backend.ROCM_ATTN: + if backend == AttentionBackendEnum.ROCM_ATTN: # k/v as 1st dimention # HND: [num_blocks, num_kv_heads, block_size, head_size] kv_cache = torch.zeros( @@ -116,7 +116,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: # k/v as 1st dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -128,7 +128,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.TRITON_ATTN: + elif backend == AttentionBackendEnum.TRITON_ATTN: # k/v as 2nd dimention # NHD: [num_blocks, block_size, num_kv_heads, head_size] kv_cache = torch.zeros( @@ -140,7 +140,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: dtype=self.kv_cache_dtype, device=self.device, ) - elif backend == _Backend.FLASHINFER: + elif backend == AttentionBackendEnum.FLASHINFER: kv_cache = torch.zeros( num_blocks, 2, @@ -244,8 +244,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): MODELS_FP4: list[tuple[str, type]] = [] HEADS: list[tuple[int, int]] = [] SPLIT_ATTENTION: list[bool] = [] -BACKENDS_FP8: list[_Backend] = [] -BACKENDS_FP4: list[_Backend] = [] +BACKENDS_FP8: list[AttentionBackendEnum] = [] +BACKENDS_FP4: list[AttentionBackendEnum] = [] if current_platform.is_cuda(): HEADS = [(64, 8), (40, 8)] @@ -261,8 +261,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): TestAttentionNvfp4QuantPatternModel, ) ] - BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] - BACKENDS_FP4 = [_Backend.FLASHINFER] + BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER] + BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER] elif current_platform.is_rocm(): HEADS = [(32, 8), (40, 8)] @@ -270,9 +270,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] BACKENDS = [ - 
_Backend.ROCM_AITER_UNIFIED_ATTN, - _Backend.ROCM_ATTN, - _Backend.TRITON_ATTN, + AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + AttentionBackendEnum.ROCM_ATTN, + AttentionBackendEnum.TRITON_ATTN, ] @@ -302,11 +302,11 @@ def test_attention_quant_pattern( custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, + backend: AttentionBackendEnum, dist_init, ): """Test AttentionStaticQuantPattern fusion pass""" - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -314,6 +314,7 @@ def test_attention_quant_pattern( custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") + torch.set_default_dtype(dtype) torch.manual_seed(42) vllm_config = VllmConfig( @@ -402,7 +403,7 @@ def test_attention_quant_pattern( result_fused_1 = model_compiled(q, k, v) - if backend == _Backend.FLASHINFER: + if backend == AttentionBackendEnum.FLASHINFER: # With the Flashinfer backend after the 1st round of the forward # pass, output quant scale should be loaded into the attn layer's # _o_scale_float, the 2nd round should reuse the loaded diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 4b910bc28579..f67063cdf42e 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -11,7 +11,7 @@ import pytest import regex as re -from tests.v1.attention.utils import _Backend +from tests.v1.attention.utils import AttentionBackendEnum from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform @@ -24,7 +24,7 @@ class ModelBackendTestCase(NamedTuple): model_name: str model_kwargs: dict[str, Any] - backend: _Backend + backend: AttentionBackendEnum attention_fusions: int allreduce_fusions: int | None = None @@ -39,14 +39,14 @@ class ModelBackendTestCase(NamedTuple): # Use smaller model for L40s in CI model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, allreduce_fusions=65, ), ModelBackendTestCase( model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=48, allreduce_fusions=96, ), @@ -56,7 +56,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), - backend=_Backend.FLASHINFER, + backend=AttentionBackendEnum.FLASHINFER, attention_fusions=32, allreduce_fusions=65, ), @@ -67,7 +67,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="meta-llama/Llama-3.1-8B-Instruct", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=65, ), @@ -85,19 +85,19 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", 
model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_ATTN, + backend=AttentionBackendEnum.ROCM_ATTN, attention_fusions=32, ), ModelBackendTestCase( model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_kwargs=dict(max_model_len=1024), - backend=_Backend.ROCM_AITER_UNIFIED_ATTN, + backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, attention_fusions=32, ), ] @@ -117,7 +117,7 @@ class ModelBackendTestCase(NamedTuple): def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, @@ -125,7 +125,7 @@ def test_attn_quant( caplog_mp_spawn, monkeypatch, ): - if backend == _Backend.FLASHINFER and ( + if backend == AttentionBackendEnum.FLASHINFER and ( not current_platform.is_device_capability((10, 0)) or not has_flashinfer() ): pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") @@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: def test_tp2_attn_quant_allreduce_rmsnorm( model_name: str, model_kwargs: dict, - backend: _Backend, + backend: AttentionBackendEnum, attention_fusions: int, allreduce_fusions: int, custom_ops: str, diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py index b1a09d88ed9d..3d02893e52f1 100644 --- a/tests/config/test_multimodal_config.py +++ b/tests/config/test_multimodal_config.py @@ -3,13 +3,13 @@ import pytest -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.multimodal import MultiModalConfig def test_mm_encoder_attn_backend_str_conversion(): config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") - assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN + assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN def test_mm_encoder_attn_backend_invalid(): @@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid(): def test_mm_encoder_attn_backend_hash_updates(): base_hash = MultiModalConfig().compute_hash() overridden_hash = MultiModalConfig( - mm_encoder_attn_backend=_Backend.FLASH_ATTN + mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN ).compute_hash() assert base_hash != overridden_hash diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 8149ce7672cd..29cc81be12e4 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -120,12 +120,13 @@ def test_env( elif device == "cuda": with patch("vllm.platforms.current_platform", CudaPlatform()): + capability = torch.cuda.get_device_capability() if use_mla: # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 - # and Blackwell GPUs (SM 10.0), V1 only + # and Blackwell GPUs (SM 10.x), V1 only # - FLASHINFER_MLA: only supported on Blackwell GPUs - # (SM 10.0+), V1 only + # (SM 10.x), V1 only # - FLASHMLA: only supported with block_size == 64 # - FLASH_ATTN_MLA: V1 only # - TRITON_MLA: fallback for other cases @@ -134,58 +135,72 @@ def test_env( if block_size != 128: # CUTLASS_MLA only supports block_size == 128 pytest.skip("CUTLASS_MLA only supports block_size 128") - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "CUTLASS_MLA" - assert backend.get_name() == expected + if capability[0] != 10: + pytest.skip("CUTLASS MLA is not supported on this 
platform") + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "CUTLASS_MLA" + assert backend.get_name() == expected elif name == "FLASHINFER_MLA": + if capability[0] != 10: + pytest.skip( + "FlashInfer MLA is not supported on this platform" + ) if block_size not in [32, 64]: # FlashInfer MLA only supports block_size 32 or 64 pytest.skip( "FlashInfer MLA only supports block_size 32 or 64" ) - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "FLASHINFER_MLA" - assert backend.get_name() == expected + backend = get_attn_backend( + 576, torch.float16, None, block_size, use_mla=use_mla + ) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected elif name == "FLASHMLA": if block_size != 64: # FlashMLA only supports block_size == 64 pytest.skip("FlashMLA only supports block_size 64") - else: - from vllm.v1.attention.backends.mla.flashmla import ( - is_flashmla_dense_supported, - ) + from vllm.v1.attention.backends.mla.flashmla import ( + is_flashmla_dense_supported, + ) - is_supported, _ = is_flashmla_dense_supported() - if not is_supported: - pytest.skip("FlashMLA not supported on this platform") - else: - backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla - ) - expected = name - assert backend.get_name() == expected + is_supported, _ = is_flashmla_dense_supported() + if not is_supported: + pytest.skip("FlashMLA not supported on this platform") + backend = get_attn_backend( + 576, + torch.float16, + None, + block_size, + use_mla=use_mla, + ) + expected = name + assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, + ) + + if not flash_attn_supports_mla(): + pytest.skip( + "FlashAttention MLA not supported on this platform" + ) backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 576, torch.float16, None, block_size, use_mla=use_mla ) expected = "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": backend = get_attn_backend( - 16, torch.float16, None, block_size, use_mla=use_mla + 64, torch.float16, None, block_size, use_mla=use_mla ) expected = "FLASHINFER" assert backend.get_name() == expected diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 14d1618bca3c..183bbf3bf4e0 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -11,7 +11,7 @@ import pytest import torch -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import MultiHeadAttention from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform @@ -43,14 +43,14 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA elif device == "hip": with ( patch("vllm.attention.layer.current_platform", RocmPlatform()), 
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention @@ -59,7 +59,7 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA not available @@ -73,7 +73,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == AttentionBackendEnum.XFORMERS # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA available @@ -96,7 +96,7 @@ def test_mha_attn_platform(device: str): ), ): attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN def ref_attention( diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 48a6f34366cf..8c4bd6eaa2dd 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -93,6 +93,17 @@ def _initialize_kv_caches_v1(self, vllm_config): "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" ) + if model_arch == "DeepseekV32ForCausalLM": + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability and capability.major < 9: + pytest.skip( + f"DeepseekV32 requires Hopper (9.0+) or Blackwell (10.0+) " + f"for FLASHMLA_SPARSE backend. 
Current device has compute " + f"capability {capability.major}.{capability.minor}" + ) + with ( patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), monkeypatch.context() as m, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 08aeb6f298f6..b46002c5fa8f 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -15,7 +15,7 @@ create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv @@ -27,11 +27,11 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLASHINFER, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TREE_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -39,7 +39,7 @@ try: import flashinfer # noqa: F401 except ImportError: - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER) def _convert_dtype_to_torch(dtype): @@ -192,7 +192,7 @@ def __init__(self, device: torch.device): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -211,13 +211,13 @@ def run_attention_backend( use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") if backend == "FLEX_ATTENTION_SLOW": - actual_backend = _Backend.FLEX_ATTENTION + actual_backend = AttentionBackendEnum.FLEX_ATTENTION use_direct_block_mask = False builder_cls, impl_cls = try_get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if actual_backend == _Backend.FLASHINFER: + if actual_backend == AttentionBackendEnum.FLASHINFER: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -246,7 +246,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) - if actual_backend == _Backend.FLEX_ATTENTION: + if actual_backend == AttentionBackendEnum.FLEX_ATTENTION: builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, @@ -289,7 +289,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): def _test_backend_correctness( batch_spec: BatchSpec, model: str, - backend_to_test: list[_Backend | str], + backend_to_test: list[AttentionBackendEnum | str], mask_mod, *, block_size: int = 16, @@ -455,17 +455,20 @@ def _test_backend_correctness( # Select the appropriate KV cache format for each backend kv_cache_for_backend = kv_cache reset_kv_cache_layout = False - if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): + if backend_name in ( + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + ): kv_cache_for_backend = kv_cache.transpose(0, 1) - if backend_name == _Backend.FLASHINFER: + if backend_name == AttentionBackendEnum.FLASHINFER: # For FlashInfer default to HND layout and kv_cache_for_backend = ( kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) ) set_kv_cache_layout("HND") reset_kv_cache_layout = True - elif 
backend_name == _Backend.TRITON_ATTN: + elif backend_name == AttentionBackendEnum.TRITON_ATTN: kv_cache_for_backend = kv_cache_for_backend.contiguous() try: @@ -547,7 +550,9 @@ def causal_mask_mod( batch_spec = BATCH_SPECS[batch_spec_name] LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS @@ -573,9 +578,9 @@ def causal_mask_mod( SLIDING_WINDOW_BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, - _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, "FLEX_ATTENTION_SLOW", ] @@ -612,7 +617,9 @@ def sliding_window_mask_mod( ) LARGE_BLOCK_BACKENDS = ( - [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + [AttentionBackendEnum.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") + else [] ) SMALL_BLOCK_BACKENDS = [ x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 5679fafe63ee..1bd05e6183dc 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -18,12 +18,11 @@ try_get_attention_backend, ) from vllm import _custom_ops as ops -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.config.vllm import set_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.mla.common import QueryLenSupport @@ -31,25 +30,25 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, - _Backend.FLASHMLA, - _Backend.FLASH_ATTN_MLA, - _Backend.FLASHINFER_MLA, - _Backend.TRITON_MLA, + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.TRITON_MLA, ] # Remove sm100 backends from the list if not using sm100 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: - BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) - BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA) # Remove FLASH_ATTN_MLA from the list if not supported if not flash_attn_supports_mla(): - BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA) # Remove FLASHMLA from the list if not supported if not is_flashmla_dense_supported()[0]: - BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA) SPEC_DECODE_BACKENDS = [] for backend in BACKENDS_TO_TEST: @@ -62,9 +61,7 @@ BACKEND_BLOCK_SIZES = {} for backend in BACKENDS_TO_TEST: - backend_class_str = backend_to_class_str(backend) - backend_class = resolve_obj_by_qualname(backend_class_str) - supported_sizes = 
backend_class.get_supported_kernel_block_size() + supported_sizes = backend.get_class().supported_kernel_block_sizes if supported_sizes: default_size = supported_sizes[0] block_size = ( @@ -291,7 +288,7 @@ def get_kv_cache_spec(self, vllm_config): def run_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, kv_cache_spec: FullAttentionSpec, layer_names: list[str], vllm_config, @@ -813,7 +810,7 @@ def test_backend_correctness( # Create a summary for the single-line failure message backend_names = [] for f in failures: - if "[_Backend." in f: + if "[AttentionBackendEnum." in f: backend_name = f.split("[")[1].split("]")[0] backend_names.append(backend_name) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index b166d9d4ff68..dea89babd4b4 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -8,7 +8,7 @@ import torch from vllm.attention.backends.abstract import AttentionImpl -from vllm.attention.backends.registry import _Backend, backend_to_class_str +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -20,7 +20,6 @@ VllmConfig, ) from vllm.config.model import ModelDType -from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, @@ -120,15 +119,14 @@ def create_common_attn_metadata( def try_get_attention_backend( - backend: _Backend, + backend: AttentionBackendEnum, ) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: """Try to get the attention backend class, skipping test if not found.""" - backend_class_str = backend_to_class_str(backend) try: - backend_class = resolve_obj_by_qualname(backend_class_str) + backend_class = backend.get_class() return backend_class.get_builder_cls(), backend_class.get_impl_cls() except ImportError as e: - pytest.skip(f"{backend_class_str} not available: {e}") + pytest.skip(f"{backend.name} not available: {e}") raise AssertionError("unreachable") from None diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 47d05a20a65d..89d0ec769ac0 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -13,7 +13,7 @@ create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -534,11 +534,17 @@ def create_deterministic_logits(token_ids): sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TRITON_ATTN + ) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -673,7 +679,9 @@ def create_deterministic_logits(token_ids, k: int): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata builder. 
- attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.TREE_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 9ca7cf9e3e0e..6d59b58e739e 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -12,7 +12,7 @@ create_standard_kv_cache_spec, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, DeviceConfig, @@ -177,7 +177,9 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): sampling_metadata = mock.MagicMock() # Setup attention metadata - attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index b365e75d5514..6958d62dc7e9 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -10,7 +10,7 @@ create_vllm_config, try_get_attention_backend, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -35,7 +35,7 @@ def forward_attention( block_table: torch.Tensor, slot_mapping: torch.Tensor, seqlen_k: int, - backend: _Backend, + backend: AttentionBackendEnum, spec_token_tree: str | None = None, num_spec_tokens: int = 0, ) -> torch.Tensor: @@ -241,7 +241,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=tree_slot_mapping, seqlen_k=seqlen_k, - backend=_Backend.TREE_ATTN, + backend=AttentionBackendEnum.TREE_ATTN, spec_token_tree=spec_token_tree, num_spec_tokens=tree_size_q - 1, ).view(batch_size, -1, num_heads, dim_per_head) @@ -278,7 +278,7 @@ def test_tree_attn_correctness() -> None: block_table=block_table, slot_mapping=branch_slot_mapping, seqlen_k=sequence_position + q_len, - backend=_Backend.FLASH_ATTN, + backend=AttentionBackendEnum.FLASH_ATTN, ).view(batch_size, -1, num_heads, dim_per_head) # Compare the outputs. diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index bc624658308b..b02d9a657407 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -185,9 +185,7 @@ def _make_mock_backend_for_kernel_block_size( supported_sizes: list[int | MultipleOf], ): class _MockBackend: - @staticmethod - def get_supported_kernel_block_size(): - return supported_sizes + supported_kernel_block_sizes = supported_sizes return _MockBackend() @@ -466,13 +464,20 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): # This test checks if GPUModelRunner initializes correctly when an attention # backend enforces a non-default KV cache stride order. 
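The parametrized spec-decode tests above keep an if/elif chain even though the parametrized strings match the enum member names exactly. A hedged sketch of the equivalent single lookup, shown only to illustrate the enum's item access; it is not what the patch does:

from vllm.attention.backends.registry import AttentionBackendEnum

# try_get_attention_backend is the helper from tests/v1/attention/utils.py;
# attn_backend is the parametrized string, e.g. "FLASH_ATTN" or "TREE_ATTN".
attn_metadata_builder_cls, _ = try_get_attention_backend(
    AttentionBackendEnum[attn_backend]  # unknown names raise a ValueError
)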
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) - expected_kv_cache_shape = [ - 2, - NUM_BLOCKS, - BLOCK_SIZE, - n_heads, - model_runner.model_config.get_head_size(), - ] + head_size = model_runner.model_config.get_head_size() + + # Get the expected shape from the backend's get_kv_cache_shape method + # to ensure compatibility with different backends (triton vs flexattention) + attn_backend = None + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + break + + assert attn_backend is not None, "No attention backend found" + expected_kv_cache_shape = list( + attn_backend.get_kv_cache_shape(NUM_BLOCKS, BLOCK_SIZE, n_heads, head_size) + ) + # TODO mla test default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index b54eaf4e2872..697beed91869 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,13 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args import torch from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey +if TYPE_CHECKING: + from vllm.config.cache import CacheDType + from vllm.platforms.interface import DeviceCapability + from vllm.v1.attention.backends.utils import KVCacheLayoutType + class AttentionType: """ @@ -40,6 +45,9 @@ class AttentionBackend(ABC): # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. 
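The hunk below replaces the per-implementation get_supported_kernel_block_size() hook with class-level capability declarations plus a validate_configuration() entry point. A hedged sketch of a backend declaring and checking them; ToyBackend and its values are illustrative, and the abstract methods are omitted:

import torch
from vllm.attention.backends.abstract import AttentionBackend

class ToyBackend(AttentionBackend):  # sketch only, abstract methods omitted
    supported_dtypes = [torch.bfloat16]
    supported_kernel_block_sizes = [16, 32]  # MultipleOf(n) markers also work
    supported_kv_cache_dtypes = ["auto", "fp8"]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [64, 128]

reasons = ToyBackend.validate_configuration(
    head_size=128, dtype=torch.bfloat16, kv_cache_dtype="auto", block_size=16,
    use_mla=False, has_sink=False, use_sparse=False,
    device_capability=None,  # ignored by the base supports_compute_capability
)
assert reasons == []  # an empty list means the combination is supported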
accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)] + supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"] @staticmethod @abstractmethod @@ -51,10 +59,6 @@ def get_name() -> str: def get_impl_cls() -> type["AttentionImpl"]: raise NotImplementedError - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return cls.get_impl_cls().get_supported_kernel_block_size() - @staticmethod @abstractmethod def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]: @@ -79,6 +83,136 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + supported_head_sizes = cls.get_supported_head_sizes() + return (not supported_head_sizes) or head_size in supported_head_sizes + + @classmethod + def supports_dtype(cls, dtype: torch.dtype) -> bool: + return dtype in cls.supported_dtypes + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool: + if kv_cache_dtype is None: + return True + return (not cls.supported_kv_cache_dtypes) or ( + kv_cache_dtype in cls.supported_kv_cache_dtypes + ) + + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + from vllm.config.cache import BlockSize + + if block_size is None: + return True + + valid_sizes = get_args(BlockSize) + if block_size not in valid_sizes: + return False + + if not cls.supported_kernel_block_sizes: + return True + + for supported_size in cls.supported_kernel_block_sizes: + is_multiple_of = ( + isinstance(supported_size, MultipleOf) + and block_size % supported_size.base == 0 + ) + is_int_equal = ( + isinstance(supported_size, int) and block_size == supported_size + ) + if is_multiple_of or is_int_equal: + return True + return False + + @classmethod + def is_mla(cls) -> bool: + return False + + @classmethod + def supports_sink(cls) -> bool: + return False + + @classmethod + def is_sparse(cls) -> bool: + return False + + @classmethod + def supports_compute_capability(cls, capability: "DeviceCapability") -> bool: + return True + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: "DeviceCapability", + ) -> str | None: + return None + + @classmethod + def validate_configuration( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: "DeviceCapability", + ) -> list[str]: + invalid_reasons = [] + if not cls.supports_head_size(head_size): + invalid_reasons.append("head_size not supported") + if not cls.supports_dtype(dtype): + invalid_reasons.append("dtype not supported") + if not cls.supports_kv_cache_dtype(kv_cache_dtype): + invalid_reasons.append("kv_cache_dtype not supported") + if not cls.supports_block_size(block_size): + invalid_reasons.append("block_size not supported") + if use_mla != cls.is_mla(): + if use_mla: + invalid_reasons.append("MLA not supported") + else: + invalid_reasons.append("non-MLA not supported") + if has_sink and not cls.supports_sink(): + 
invalid_reasons.append("sink setting not supported") + if use_sparse != cls.is_sparse(): + if use_sparse: + invalid_reasons.append("sparse not supported") + else: + invalid_reasons.append("non-sparse not supported") + if not cls.supports_compute_capability(device_capability): + invalid_reasons.append("compute capability not supported") + combination_reason = cls.supports_combination( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, + ) + if combination_reason is not None: + invalid_reasons.append(combination_reason) + return invalid_reasons + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return None + class AttentionMetadata: pass @@ -151,11 +285,6 @@ def __init__( ) -> None: raise NotImplementedError - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # TODO: implement this function for all backends. - return [MultipleOf(1)] - @abstractmethod def forward( self, diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 05d0159d0861..768d15cb9c82 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,108 +3,192 @@ """Attention backend registry""" import enum +from collections.abc import Callable +from typing import TYPE_CHECKING, cast +from vllm.logger import init_logger from vllm.utils.import_utils import resolve_obj_by_qualname +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - TRITON_ATTN = enum.auto() - XFORMERS = enum.auto() - ROCM_ATTN = enum.auto() - ROCM_AITER_MLA = enum.auto() - ROCM_AITER_FA = enum.auto() # used for ViT attn backend - TORCH_SDPA = enum.auto() - FLASHINFER = enum.auto() - FLASHINFER_MLA = enum.auto() - TRITON_MLA = enum.auto() - CUTLASS_MLA = enum.auto() - FLASHMLA = enum.auto() - FLASHMLA_SPARSE = enum.auto() - FLASH_ATTN_MLA = enum.auto() - PALLAS = enum.auto() - IPEX = enum.auto() - NO_ATTENTION = enum.auto() - FLEX_ATTENTION = enum.auto() - TREE_ATTN = enum.auto() - ROCM_AITER_UNIFIED_ATTN = enum.auto() - - -BACKEND_MAP = { - _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 - _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 - _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 - _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501 - _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501 - _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501 - _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501 - _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 - _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 - _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 - _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501 - _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501 - _Backend.FLASH_ATTN_MLA: 
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 - _Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501 - _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 - _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501 - _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501 -} - - -def register_attn_backend(backend: _Backend, class_path: str | None = None): - """ - Decorator: register a custom attention backend into BACKEND_MAPPING. - - If class_path is provided, use it. - - Otherwise, auto-generate from the class object. - Validation: only checks if 'backend' is a valid _Backend enum member. - Overwriting existing mappings is allowed. This enables other hardware - platforms to plug in custom out-of-tree backends. - """ - if not isinstance(backend, _Backend): - raise ValueError(f"{backend} is not a valid _Backend enum value.") +logger = init_logger(__name__) - def decorator(cls): - path = class_path or f"{cls.__module__}.{cls.__qualname__}" - BACKEND_MAP[backend] = path - return cls - return decorator +class _AttentionBackendEnumMeta(enum.EnumMeta): + """Metaclass for AttentionBackendEnum to provide better error messages.""" + def __getitem__(cls, name: str): + """Get backend by name with helpful error messages.""" + try: + return super().__getitem__(name) + except KeyError: + members = cast("dict[str, AttentionBackendEnum]", cls.__members__).values() + valid_backends = ", ".join(m.name for m in members) + raise ValueError( + f"Unknown attention backend: '{name}'. " + f"Valid options are: {valid_backends}" + ) from None -def backend_to_class_str(backend: _Backend) -> str: - """Get the backend class string - Args: - backend: The backend enum value +class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): + """Enumeration of all supported attention backends. - Returns: - The backend class string - """ - return BACKEND_MAP[backend] + The enum value is the default class path, but this can be overridden + at runtime using register_backend(). + To get the actual backend class (respecting overrides), use: + backend.get_class() + """ -def backend_to_class(backend: _Backend) -> type: - """Get the backend class. 
+ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" + ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" + ROCM_AITER_FA = ( + "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ) + TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + FLASHINFER_MLA = ( + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + ) + TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" + FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" + FLASHMLA_SPARSE = ( + "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" + ) + FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend" + NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" + FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" + TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" + ROCM_AITER_UNIFIED_ATTN = ( + "vllm.v1.attention.backends.rocm_aiter_unified_attn." + "RocmAiterUnifiedAttentionBackend" + ) + # Placeholder for third-party/custom backends - must be registered before use + CUSTOM = "" + + def get_path(self, include_classname: bool = True) -> str: + """Get the class path for this backend (respects overrides). + + Returns: + The fully qualified class path string + + Raises: + ValueError: If Backend.CUSTOM is used without being registered + """ + path = _OVERRIDES.get(self, self.value) + if not path: + raise ValueError( + f"Backend {self.name} must be registered before use. " + f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')" + ) + if not include_classname: + path = path.rsplit(".", 1)[0] + return path + + def get_class(self) -> "type[AttentionBackend]": + """Get the backend class (respects overrides). + + Returns: + The backend class + + Raises: + ImportError: If the backend class cannot be imported + ValueError: If Backend.CUSTOM is used without being registered + """ + return resolve_obj_by_qualname(self.get_path()) + + def is_overridden(self) -> bool: + """Check if this backend has been overridden. + + Returns: + True if the backend has a registered override + """ + return self in _OVERRIDES + + def clear_override(self) -> None: + """Clear any override for this backend, reverting to the default.""" + _OVERRIDES.pop(self, None) + + +_OVERRIDES: dict[AttentionBackendEnum, str] = {} + + +def register_backend( + backend: AttentionBackendEnum, class_path: str | None = None +) -> Callable[[type], type]: + """Register or override a backend implementation. Args: - backend: The backend enum value + backend: The AttentionBackendEnum member to register + class_path: Optional class path. If not provided and used as + decorator, will be auto-generated from the class. 
Returns: - The backend class + Decorator function if class_path is None, otherwise a no-op + + Examples: + # Override an existing backend + @register_backend(AttentionBackendEnum.FLASH_ATTN) + class MyCustomFlashAttn: + ... + + # Register a custom third-party backend + @register_backend(AttentionBackendEnum.CUSTOM) + class MyCustomBackend: + ... + + # Direct registration + register_backend( + AttentionBackendEnum.CUSTOM, + "my.module.MyCustomBackend" + ) """ - backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) + def decorator(cls: type) -> type: + _OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}" + return cls -def backend_name_to_enum(backend_name: str) -> _Backend | None: - """ - Convert a string backend name to a _Backend enum value. + if class_path is not None: + _OVERRIDES[backend] = class_path + return lambda x: x - Returns: - _Backend: enum value if backend_name is a valid in-tree type - None: otherwise it's an invalid in-tree type or an out-of-tree platform - is loaded. + return decorator + + +# Backwards compatibility alias for plugins +class _BackendMeta(type): + """Metaclass to provide deprecation warnings when accessing _Backend.""" + + def __getattribute__(cls, name: str): + if name not in ("__class__", "__mro__", "__name__"): + logger.warning( + "_Backend has been renamed to AttentionBackendEnum. " + "Please update your code to use AttentionBackendEnum instead. " + "_Backend will be removed in a future release." + ) + return getattr(AttentionBackendEnum, name) + + def __getitem__(cls, name: str): + logger.warning( + "_Backend has been renamed to AttentionBackendEnum. " + "Please update your code to use AttentionBackendEnum instead. " + "_Backend will be removed in a future release." + ) + return AttentionBackendEnum[name] + + +class _Backend(metaclass=_BackendMeta): + """Deprecated: Use AttentionBackendEnum instead. + + This class is provided for backwards compatibility with plugins + and will be removed in a future release. 
""" - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else None + + pass diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index acab0529f352..ec705126c710 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -12,7 +12,7 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -99,40 +99,44 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, use_upstream_fa: bool, - attn_backend_override: _Backend | None = None, -) -> tuple[_Backend, Callable | None]: + attn_backend_override: AttentionBackendEnum | None = None, +) -> tuple[AttentionBackendEnum, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = _Backend.ROCM_AITER_FA + attn_backend = AttentionBackendEnum.ROCM_AITER_FA elif ( check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9() and attn_backend_override is None ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None elif current_platform.is_cuda(): - if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - attn_backend = _Backend.FLASH_ATTN + attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True elif current_platform.is_xpu(): - assert attn_backend == _Backend.FLASH_ATTN, ( + assert attn_backend == AttentionBackendEnum.FLASH_ATTN, ( "XPU platform only supports FLASH_ATTN as vision attention backend." 
) use_upstream_fa = False else: - return _Backend.TORCH_SDPA, None + return AttentionBackendEnum.TORCH_SDPA, None - if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: - if attn_backend == _Backend.ROCM_AITER_FA: + if attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: + if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: if use_upstream_fa: @@ -309,7 +313,7 @@ def __init__( kv_sharing_target_layer_name, **extra_impl_args, ) - self.backend = backend_name_to_enum(self.attn_backend.get_name()) + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -530,13 +534,13 @@ def __init__( backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.PALLAS, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, } - else _Backend.TORCH_SDPA + else AttentionBackendEnum.TORCH_SDPA ) self.attn_backend, self._flash_attn_varlen_func = ( @@ -547,17 +551,23 @@ def __init__( ) ) - if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): - self.attn_backend = _Backend.TORCH_SDPA + if ( + self.attn_backend == AttentionBackendEnum.XFORMERS + and not check_xformers_availability() + ): + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } # this condition is just to make sure that the # use_upstream_fa in the log is correct - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): use_upstream_fa = True logger.info_once( @@ -606,17 +616,17 @@ def forward( max_seqlen_k=kv_len, softmax_scale=self.scale, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward( query, key, value, scale=self.scale ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.PALLAS: + elif self.attn_backend == AttentionBackendEnum.PALLAS: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9c26a8d40eda..6e5fa854d35f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -4,14 +4,15 @@ import os from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass from functools import cache +from typing import cast, get_args import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.cache import CacheDType from vllm.logger import init_logger 
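layer.py and the selector now resolve backend names with a plain enum lookup. Unlike the removed backend_name_to_enum helper, which returned None for unknown names, the lookup fails loudly; a small sketch of the behaviour, with the error text coming from _AttentionBackendEnumMeta above:

from vllm.attention.backends.registry import AttentionBackendEnum, _Backend

backend = AttentionBackendEnum["FLASH_ATTN"]  # -> AttentionBackendEnum.FLASH_ATTN

try:
    AttentionBackendEnum["NOT_A_BACKEND"]
except ValueError as e:
    # "Unknown attention backend: 'NOT_A_BACKEND'. Valid options are: ..."
    print(e)

# The deprecated alias still resolves members, logging a warning each time.
assert _Backend["FLASH_ATTN"] is AttentionBackendEnum.FLASH_ATTN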
from vllm.utils import STR_BACKEND_ENV_VAR from vllm.utils.import_utils import resolve_obj_by_qualname @@ -19,18 +20,18 @@ logger = init_logger(__name__) -def get_env_variable_attn_backend() -> _Backend | None: +def get_env_variable_attn_backend() -> AttentionBackendEnum | None: """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. Returns: - * _Backend enum value if an override is specified + * AttentionBackendEnum value if an override is specified * None otherwise """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else backend_name_to_enum(backend_name) + return None if backend_name is None else AttentionBackendEnum[backend_name] # Global state allows a particular choice of backend @@ -40,10 +41,10 @@ def get_env_variable_attn_backend() -> _Backend | None: # # THIS SELECTION TAKES PRECEDENCE OVER THE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: _Backend | None = None +forced_attn_backend: AttentionBackendEnum | None = None -def global_force_attn_backend(attn_backend: _Backend | None) -> None: +def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: """ Force all attention operations to use a specified backend. @@ -58,7 +59,7 @@ def global_force_attn_backend(attn_backend: _Backend | None) -> None: forced_attn_backend = attn_backend -def get_global_forced_attn_backend() -> _Backend | None: +def get_global_forced_attn_backend() -> AttentionBackendEnum | None: """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. @@ -66,78 +67,28 @@ def get_global_forced_attn_backend() -> _Backend | None: return forced_attn_backend -@dataclass(frozen=True) -class _IsSupported: - can_import: bool - head_size: bool - dtype: bool - - def __bool__(self) -> bool: - return self.can_import and self.head_size and self.dtype - - -def is_attn_backend_supported( - attn_backend: str | type[AttentionBackend], - head_size: int, - dtype: torch.dtype, - *, - allow_import_error: bool = True, -) -> _IsSupported: - if isinstance(attn_backend, str): - try: - attn_backend = resolve_obj_by_qualname(attn_backend) - except ImportError: - if not allow_import_error: - raise - - return _IsSupported(can_import=False, head_size=False, dtype=False) - - assert isinstance(attn_backend, type) - - # TODO: Update the interface once V0 is removed - if get_supported_head_sizes := getattr( - attn_backend, "get_supported_head_sizes", None - ): - is_head_size_supported = head_size in get_supported_head_sizes() - elif validate_head_size := getattr(attn_backend, "validate_head_size", None): - try: - validate_head_size(head_size) - is_head_size_supported = True - except Exception: - is_head_size_supported = False - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support head size validation" - ) - - if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None): - is_dtype_supported = dtype in get_supported_dtypes() - else: - raise NotImplementedError( - f"{attn_backend.__name__} does not support dtype validation" - ) - - return _IsSupported( - can_import=True, - head_size=is_head_size_supported, - dtype=is_dtype_supported, - ) - - def get_attn_backend( head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, - block_size: int, + block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily 
imports it.""" + + if kv_cache_dtype is not None: + valid_cache_dtypes = get_args(CacheDType) + assert kv_cache_dtype in valid_cache_dtypes, ( + f"Invalid kv_cache_dtype: {kv_cache_dtype}. " + f"Valid values are: {valid_cache_dtypes}" + ) + return _cached_get_attn_backend( head_size=head_size, dtype=dtype, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), block_size=block_size, use_mla=use_mla, has_sink=has_sink, @@ -149,8 +100,8 @@ def get_attn_backend( def _cached_get_attn_backend( head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, + kv_cache_dtype: CacheDType | None, + block_size: int | None, use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, @@ -161,7 +112,9 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. selected_backend = None - backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() + backend_by_global_setting: AttentionBackendEnum | None = ( + get_global_forced_attn_backend() + ) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -177,12 +130,13 @@ def _cached_get_attn_backend( STR_BACKEND_ENV_VAR, ) backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - selected_backend = backend_name_to_enum(backend_by_env_var) - if selected_backend is None: + try: + selected_backend = AttentionBackendEnum[backend_by_env_var] + except KeyError as e: raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. " - f"Valid backends are: {list(_Backend.__members__.keys())}" - ) + f"Invalid attention backend: '{backend_by_env_var}'. Valid " + f"backends are: {list(AttentionBackendEnum.__members__.keys())}" + ) from e # get device-specific attn_backend from vllm.platforms import current_platform @@ -202,12 +156,26 @@ def _cached_get_attn_backend( raise ValueError( f"Invalid attention backend for {current_platform.device_name}" ) - return resolve_obj_by_qualname(attention_cls) + backend = resolve_obj_by_qualname(attention_cls) + + # Adjust kv cache layout if the selected backend requires a specific one + required_layout = backend.get_required_kv_cache_layout() + if required_layout is not None: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout(required_layout) + logger.info( + "Using %s KV cache layout for %s backend.", + required_layout, + backend.get_name(), + ) + + return backend @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend, + attn_backend: AttentionBackendEnum, ) -> Generator[None, None, None]: """ Globally force a vLLM attention backend override within a diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 031df3091f1c..864cf1be81b2 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -21,7 +21,15 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] -CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal[ + "auto", + "bfloat16", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + "fp8_inc", + "fp8_ds_mla", +] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] diff --git a/vllm/config/model.py b/vllm/config/model.py index 44c044c76168..6ce91ebb87b9 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -45,7 +45,7 @@ import vllm.model_executor.layers.quantization as me_quant import 
vllm.model_executor.models as me_models - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config.load import LoadConfig from vllm.config.parallel import ParallelConfig from vllm.model_executor.layers.quantization import QuantizationMethods @@ -53,7 +53,7 @@ else: PretrainedConfig = Any - _Backend = Any + AttentionBackendEnum = Any me_quant = LazyLoader( "model_executor", globals(), "vllm.model_executor.layers.quantization" ) @@ -302,7 +302,7 @@ class ModelConfig: mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None - mm_encoder_attn_backend: InitVar[_Backend | str | None] = None + mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None @@ -420,7 +420,7 @@ def __post_init__( mm_processor_cache_type: MMCacheType | None, mm_shm_cache_max_object_size_mb: int | None, mm_encoder_tp_mode: MMEncoderTPMode | None, - mm_encoder_attn_backend: _Backend | str | None, + mm_encoder_attn_backend: AttentionBackendEnum | str | None, interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index ef73720efe09..9348c1b2af8c 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -11,9 +11,9 @@ from vllm.config.utils import config if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum else: - _Backend = Any + AttentionBackendEnum = Any @dataclass @@ -125,10 +125,10 @@ class MultiModalConfig: DP (which is controlled by `--data-parallel-size`). This is only supported on a per-model basis and falls back to `"weights"` if the encoder does not support DP.""" - mm_encoder_attn_backend: _Backend | None = None + mm_encoder_attn_backend: AttentionBackendEnum | None = None """Optional override for the multi-modal encoder attention backend when using vision transformers. Accepts any value from - `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" + `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`).""" interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string.""" @@ -167,26 +167,16 @@ def _validate_limit_per_prompt( @field_validator("mm_encoder_attn_backend", mode="before") @classmethod - def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: - from vllm.attention.backends.registry import ( - _Backend as BackendEnum, - ) - from vllm.attention.backends.registry import ( - backend_name_to_enum, - ) - - if value is None or isinstance(value, BackendEnum): + def _validate_mm_encoder_attn_backend( + cls, value: str | AttentionBackendEnum | None + ) -> AttentionBackendEnum | None: + if value is None or isinstance(value, AttentionBackendEnum): return value - if isinstance(value, str): - candidate = backend_name_to_enum(value.upper()) - if candidate is not None: - return candidate - - valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) - raise ValueError( - f"Invalid mm encoder attention backend. Expected one of: {valid_backends}." 
+ assert isinstance(value, str), ( + "mm_encoder_attn_backend must be a string or an AttentionBackendEnum." ) + return AttentionBackendEnum[value.upper()] @model_validator(mode="after") def _validate_multimodal_config(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ff9770b72bd3..6c20eee1ecbf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,7 @@ import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -876,9 +876,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = backend_name_to_enum(self.backend_name) - self._use_flashinfer = attn_backend == _Backend.FLASHINFER - self._use_pallas = attn_backend == _Backend.PALLAS + attn_backend = AttentionBackendEnum[self.backend_name] + self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b12b7082af62..d3913553320f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -32,7 +32,7 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( CacheConfig, CompilationConfig, @@ -462,7 +462,7 @@ class EngineArgs: MultiModalConfig.mm_shm_cache_max_object_size_mb ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode - mm_encoder_attn_backend: _Backend | str | None = ( + mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( MultiModalConfig.mm_encoder_attn_backend ) io_processor_plugin: str | None = None diff --git a/vllm/envs.py b/vllm/envs.py index 52178e5f5250..52a9671bc46e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -626,14 +626,14 @@ def get_vllm_port() -> int | None: # - "FLASH_ATTN_MLA": use FlashAttention for MLA # - "FLASHINFER_MLA": use FlashInfer for MLA # - "CUTLASS_MLA": use CUTLASS for MLA - # All possible options loaded dynamically from _Backend enum + # All possible options loaded dynamically from AttentionBackendEnum "VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND", None, lambda: list( __import__( - "vllm.attention.backends.registry", fromlist=["_Backend"] - )._Backend.__members__.keys() + "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] + ).AttentionBackendEnum.__members__.keys() ), ), # If set, vllm will use flashinfer sampler diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 6d462ad8ae62..1b2bb60a17c1 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor -from vllm.attention.backends.registry import _Backend 
+from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -256,7 +256,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -303,17 +303,17 @@ def __init__( ) ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Unsupported vision attention backend: {self.attn_backend}" ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -361,7 +361,7 @@ def forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): s = int(cu_seqlens[i - 1]) @@ -373,7 +373,7 @@ def forward( out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -514,7 +514,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() @@ -567,7 +567,7 @@ def __init__( require_post_norm: bool | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.config = config @@ -582,10 +582,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.out_hidden_size = config.hidden_size # Keep blocks for compatibility with other vision towers num_layers = ( @@ -666,11 +667,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index f287cff12086..97182a25f82b 100644 --- 
a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -36,7 +36,7 @@ from einops import rearrange, repeat from transformers import BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -164,7 +164,7 @@ def __init__( projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -211,17 +211,17 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -291,7 +291,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] for i in range(1, len(cu_seqlens)): @@ -310,7 +310,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -370,7 +370,7 @@ def __init__( norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -463,7 +463,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -515,10 +515,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -565,11 +566,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == 
_Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index b9cd3545ec45..776527fdd973 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -46,7 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -252,7 +252,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -306,18 +306,18 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -377,7 +377,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
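The same four-member membership check recurs in each vision tower touched by this patch (dots_ocr, ernie45_vl, glm4_1v, and the Qwen models further down). A hedged sketch of a shared helper such modules could use; this is a possible tidy-up, not something the patch adds:

from vllm.attention.backends.registry import AttentionBackendEnum

_VIT_ATTN_BACKENDS = frozenset({
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.TORCH_SDPA,
    AttentionBackendEnum.XFORMERS,
    AttentionBackendEnum.ROCM_AITER_FA,
})

def check_vit_attn_backend(attn_backend: AttentionBackendEnum) -> None:
    if attn_backend not in _VIT_ATTN_BACKENDS:
        raise RuntimeError(
            f"Unsupported vision attention backend: {attn_backend}")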
outputs = [] for i in range(1, len(cu_seqlens)): @@ -396,7 +396,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -425,7 +425,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -703,7 +703,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -772,10 +772,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -824,8 +825,8 @@ def compute_attn_mask_seqlen( max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 42f16ad9f3b3..80d7e6c5b0cd 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -16,7 +16,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( maybe_get_vit_flash_attn_backend, ) @@ -360,7 +360,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -414,17 +414,17 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -489,7 +489,7 @@ def forward( softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -536,7 +536,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -590,7 +590,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -685,7 +685,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -768,7 +768,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f6461ae9a412..9a4d69dea096 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,7 +10,7 @@ import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear @@ -106,7 +106,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -135,7 +135,7 @@ def _init_backbone( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 12ae15699e7d..86d7d1c11ffe 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -31,7 +31,7 @@ ) from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -580,8 +580,8 @@ def __init__( projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend: _Backend = _Backend.TORCH_SDPA, - attn_backend_override: _Backend | None = None, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, + attn_backend_override: AttentionBackendEnum | None = None, use_upstream_fa: bool = False, ) -> None: super().__init__() @@ -621,8 +621,8 @@ def __init__( ) ) self.is_flash_attn_backend = 
self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -680,10 +680,10 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.use_upstream_fa, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] @@ -702,7 +702,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: if seqlens is None: raise ValueError("xFormers attention backend requires seqlens tensor.") context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) @@ -786,8 +786,8 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", *, - attn_backend: _Backend = _Backend.TORCH_SDPA, - attn_backend_override: _Backend | None = None, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, + attn_backend_override: AttentionBackendEnum | None = None, use_upstream_fa: bool = False, ): super().__init__() @@ -847,7 +847,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -861,16 +861,16 @@ def __init__( ) self.use_upstream_fa = False if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } and check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN self.use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"PaddleOCR-VL does not support {self.attn_backend} backend now." 
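PaddleOCR-VL, like the other vision towers above, promotes the chosen backend to FLASH_ATTN when upstream flash-attn is importable. A condensed sketch of that idiom using check_upstream_fa_availability from vllm.attention.layer; the helper name maybe_upgrade_to_flash_attn is illustrative, not part of the patch:

import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability

def maybe_upgrade_to_flash_attn(
    attn_backend: AttentionBackendEnum,
) -> tuple[AttentionBackendEnum, bool]:
    # Returns (backend, use_upstream_fa).
    if attn_backend not in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    } and check_upstream_fa_availability(torch.get_default_dtype()):
        return AttentionBackendEnum.FLASH_ATTN, True
    return attn_backend, False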
@@ -943,9 +943,12 @@ def forward( max_seqlen = None seqlens = None - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = inputs_embeds @@ -966,7 +969,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -1016,7 +1019,7 @@ def __init__( config, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 48834ba699e4..3292cf8220ff 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,7 +42,7 @@ Qwen2_5_VLVisionConfig, ) -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, @@ -315,9 +315,9 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -364,13 +364,16 @@ def __init__( # On ROCm with FLASH_ATTN backend, upstream flash_attn is used from vllm.platforms import current_platform - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: + if ( + current_platform.is_rocm() + and self.attn_backend == AttentionBackendEnum.FLASH_ATTN + ): self.use_upstream_fa = True if current_platform.is_xpu(): self.use_upstream_fa = False self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -431,10 +434,10 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.use_upstream_fa, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
from vllm.platforms import current_platform @@ -450,7 +453,7 @@ def forward( v, cu_seqlens, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) @@ -478,9 +481,9 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -656,7 +659,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -708,10 +711,10 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." @@ -850,9 +853,12 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index b3999e6c934e..61057fa145f4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -329,7 +329,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -378,18 +378,18 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -460,7 +460,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform @@ -485,7 +485,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -515,7 +515,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -679,7 +679,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -739,10 +739,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -789,9 +790,12 @@ def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index da489a812f55..468b25220154 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -47,7 +47,7 @@ ) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -301,7 +301,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -377,10 +377,11 @@ def __init__( 
dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -490,9 +491,9 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1cd34bf54a35..1be35cde7dbd 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -49,7 +49,7 @@ ) from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -198,7 +198,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, + attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, use_upstream_fa: bool = False, ) -> None: super().__init__() @@ -306,7 +306,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -372,18 +372,18 @@ def __init__( ) use_upstream_fa = False if ( - self.attn_backend != _Backend.FLASH_ATTN - and self.attn_backend != _Backend.ROCM_AITER_FA + self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = AttentionBackendEnum.FLASH_ATTN use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." 
@@ -510,11 +510,11 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == AttentionBackendEnum.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bab5c1d82ded..c20bcd975ca3 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -12,7 +12,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -208,7 +208,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -264,14 +264,14 @@ def __init__( ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.ROCM_AITER_FA, }: - self.attn_backend = _Backend.TORCH_SDPA + self.attn_backend = AttentionBackendEnum.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, } def forward( @@ -308,7 +308,7 @@ def forward( attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(seq_length, -1) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
batch_size = cu_seqlens.shape[0] - 1 outputs = [] @@ -376,7 +376,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -440,7 +440,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -626,7 +626,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -667,7 +667,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 9f94387c700d..0e814e5c86ad 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -10,7 +10,7 @@ import torch from transformers import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, @@ -83,8 +83,8 @@ def get_vit_attn_backend( head_size: int, dtype: torch.dtype, *, - attn_backend_override: _Backend | None = None, -) -> _Backend: + attn_backend_override: AttentionBackendEnum | None = None, +) -> AttentionBackendEnum: """ Get the available attention backend for Vision Transformer. 
""" @@ -94,7 +94,7 @@ def get_vit_attn_backend( # Lazy import to avoid circular dependency from vllm.attention.selector import get_env_variable_attn_backend - selected_backend: _Backend | None = get_env_variable_attn_backend() + selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index ee904535ffe8..3dec6da89702 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -23,10 +23,10 @@ logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None VllmConfig = None @@ -127,7 +127,7 @@ def get_device_name(cls, device_id: int = 0) -> str: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -137,9 +137,9 @@ def get_attn_backend_cls( has_sink: bool, use_sparse: bool, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum - if selected_backend and selected_backend != _Backend.TORCH_SDPA: + if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") @@ -148,7 +148,7 @@ def get_attn_backend_cls( logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") - return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + return AttentionBackendEnum.TORCH_SDPA.get_path() @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 32734c3aba5e..43daf5e75b66 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -22,10 +22,13 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType else: - _Backend = None + AttentionBackendEnum = None + VllmConfig = None + CacheDType = None logger = init_logger(__name__) @@ -39,6 +42,49 @@ torch.backends.cuda.enable_cudnn_sdp(False) +@cache +def _get_backend_priorities( + use_mla: bool, + device_capability: DeviceCapability, +) -> list[AttentionBackendEnum]: + """Get backend priorities with lazy import to avoid circular dependency.""" + from vllm.attention.backends.registry import AttentionBackendEnum + + if use_mla: + if device_capability.major == 10: + return [ + AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.TRITON_MLA, + AttentionBackendEnum.FLASHMLA_SPARSE, + ] + else: + return [ + AttentionBackendEnum.FLASHMLA, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.TRITON_MLA, + AttentionBackendEnum.FLASHMLA_SPARSE, + ] + else: + if device_capability.major == 10: + return [ + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + ] + else: + 
return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + ] + + def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @@ -216,217 +262,171 @@ def get_current_memory_usage( return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + from vllm.attention.backends.registry import AttentionBackendEnum # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 if cls.has_device_capability(100): - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS if cls.has_device_capability(80): - FLASH_ATTN_V1 = ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - ) - from vllm.attention.selector import is_attn_backend_supported - - is_default_fa_supported = is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ) - if is_default_fa_supported: - return _Backend.FLASH_ATTN + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN else: - # Fallback to XFORMERS - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + return AttentionBackendEnum.XFORMERS @classmethod - def get_attn_backend_cls( + def get_valid_backends( cls, - selected_backend, head_size, dtype, kv_cache_dtype, block_size, - use_v1, use_mla, has_sink, use_sparse, - ) -> str: - from vllm.attention.backends.registry import _Backend - - if use_mla: - # explicitly reject non-MLA backends when MLA is enabled to avoid - # silently selecting an incompatible backend (e.g., FLASHINFER). - if selected_backend in { - _Backend.FLASHINFER, - _Backend.FLASH_ATTN, - _Backend.TRITON_ATTN, - _Backend.TREE_ATTN, - _Backend.XFORMERS, - }: - raise ValueError( - f"Attention backend {selected_backend} incompatible with MLA. " - "Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, " - "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set " - "VLLM_MLA_DISABLE=1 to disable MLA for this model." 
+ device_capability, + ) -> tuple[ + list[tuple["AttentionBackendEnum", int]], + dict["AttentionBackendEnum", list[str]], + ]: + valid_backends_priorities = [] + invalid_reasons = {} + + backend_priorities = _get_backend_priorities(use_mla, device_capability) + for priority, backend in enumerate(backend_priorities): + try: + backend_class = backend.get_class() + invalid_reasons_i = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla, + has_sink, + use_sparse, + device_capability, ) + except ImportError: + invalid_reasons_i = ["ImportError"] + if invalid_reasons_i: + invalid_reasons[backend] = invalid_reasons_i + else: + valid_backends_priorities.append((backend, priority)) - from vllm.attention.ops.flashmla import is_flashmla_dense_supported - from vllm.attention.utils.fa_utils import flash_attn_supports_mla + return valid_backends_priorities, invalid_reasons - if use_sparse: - logger.info_once("Using Sparse MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashmla_sparse." - "FlashMLASparseBackend" - ) - - use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and block_size % 128 == 0 - ) - use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None - and cls.is_device_capability(100) - and (block_size == 32 or block_size % 64 == 0) - ) - use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_dense_supported()[0] - ) - use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( - selected_backend is None and flash_attn_supports_mla() - ) - use_triton = selected_backend == _Backend.TRITON_MLA or ( - selected_backend is None + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "AttentionBackendEnum", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: "CacheDType | None", + block_size: int | None, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: + if not use_v1: + raise RuntimeError( + "V0 attention backends have been removed. Set VLLM_USE_V1=1 " + "to select a supported backend." ) - if use_cutlassmla: - logger.info_once("Using Cutlass MLA backend.", scope="local") - return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" - if use_flashinfermla: - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - logger.info_once("Using FlashInfer MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + device_capability = cls.get_device_capability() + assert device_capability is not None + + # First try checking just the selected backend, if there is one. 
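The new CUDA path above first collects, for each backend in priority order, either a usable (backend, priority) pair or a list of human-readable rejection reasons, and treats an ImportError like any other invalid reason. A self-contained sketch of that collect-and-explain pattern, with toy classes rather than the real vLLM enum and backend classes:

# Illustrative sketch only; these toy classes are not part of vLLM.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBackend:
    name: str
    min_head_size: int

    def validate_configuration(self, head_size: int) -> list[str]:
        # Empty list means "supported"; otherwise each entry explains a rejection.
        reasons = []
        if head_size < self.min_head_size:
            reasons.append(f"head_size {head_size} < {self.min_head_size}")
        return reasons


def collect_valid_backends(priorities: list[ToyBackend], head_size: int):
    valid: list[tuple[ToyBackend, int]] = []
    invalid: dict[str, list[str]] = {}
    for priority, backend in enumerate(priorities):
        reasons = backend.validate_configuration(head_size)
        if reasons:
            invalid[backend.name] = reasons
        else:
            valid.append((backend, priority))
    return valid, invalid


valid, invalid = collect_valid_backends(
    [ToyBackend("FLASH_ATTN", 32), ToyBackend("FLEX_ATTENTION", 1)], head_size=16
)
print(valid)    # [(ToyBackend(name='FLEX_ATTENTION', min_head_size=1), 1)]
print(invalid)  # {'FLASH_ATTN': ['head_size 16 < 32']}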
+ if selected_backend is not None: + try: + backend_class = selected_backend.get_class() + invalid_reasons = backend_class.validate_configuration( + head_size, + dtype, + kv_cache_dtype, + None, + use_mla, + has_sink, + use_sparse, + device_capability, ) - if use_flashmla: - if block_size % 64 != 0: - logger.warning( - "FlashMLA backend is not supported for block size %d" - " (currently only supports block size 64).", - block_size, - ) - else: - logger.info_once("Using FlashMLA backend.") - return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" - if use_flashattn: - logger.info_once("Using FlashAttention MLA backend.") - return ( - "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + except ImportError: + invalid_reasons = ["ImportError"] + if invalid_reasons: + raise ValueError( + f"Selected backend {selected_backend} is not valid for " + f"this configuration. Reason: {invalid_reasons}" ) - if use_triton: - logger.info_once("Using Triton MLA backend.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" - - FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = ( - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + else: + logger.info("Using %s backend.", selected_backend) + return selected_backend.get_path() + + # No selected backend or the selected backend is invalid, + # so we try finding a valid backend. + valid_backends_priorities, invalid_reasons = cls.get_valid_backends( + head_size, + dtype, + kv_cache_dtype, + None, + use_mla, + has_sink, + use_sparse, + device_capability, ) - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 - XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 - - use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( - "fp8" + reasons_str = ( + "{" + + ", ".join( + f"{backend.name}: [{', '.join(reasons)}]" + for backend, reasons in invalid_reasons.items() + ) + + "}" ) + config_str = ( + f"head_size: {head_size}, dtype: {dtype}, " + f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " + f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" + ) + logger.debug_once( + f"Some attention backends are not valid for {cls.device_name} with " + f"{config_str}. Reasons: {reasons_str}." + ) + if len(valid_backends_priorities) == 0: + raise ValueError( + f"No valid attention backend found for {cls.device_name} " + f"with {config_str}. Reasons: {reasons_str}." 
+ ) - if selected_backend == _Backend.FLASHINFER: - logger.info_once("Using FlashInfer backend.") - if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("HND") - return FLASHINFER_V1 - elif selected_backend == _Backend.FLEX_ATTENTION: - logger.info_once("Using FlexAttention backend.") - return FLEX_ATTENTION_V1 - elif selected_backend == _Backend.TRITON_ATTN: - logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: - logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN_V1 - elif selected_backend == _Backend.TREE_ATTN: - logger.info_once("Using Tree Attention backend.") - return TREE_ATTN_V1 - elif selected_backend == _Backend.XFORMERS: - logger.info_once("Using XFormers backend.") - return XFORMERS_V1 - - from vllm.attention.selector import is_attn_backend_supported - - # Default backends for V1 engine - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100): - if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype - ): - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - logger.info_once( - "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs." - ) - set_kv_cache_layout("HND") - - return FLASHINFER_V1 - - if not is_default_backend_supported.can_import: - logger.warning_once( - "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; " - "it is recommended to install FlashInfer for better " - "performance." - ) - - # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80): - if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): - logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False - ): - logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN_V1 - - # FlexAttention is the default for older GPUs - else: - logger.info_once("Using FlexAttention backend.") - return FLEX_ATTENTION_V1 - - assert not is_default_backend_supported - - use_flex_attention_reason = {} - if not is_default_backend_supported.head_size: - use_flex_attention_reason["head_size"] = head_size - if not is_default_backend_supported.dtype: - use_flex_attention_reason["dtype"] = dtype - - logger.info_once( - "Using FlexAttention backend for %s.", - ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), + # We have found some valid backends. Select the one with the + # highest priority. 
+ logger.info( + "Valid backends: %s", [b[0].name for b in valid_backends_priorities] ) - return FLEX_ATTENTION_V1 + sorted_indices = sorted( + range(len(valid_backends_priorities)), + key=lambda i: valid_backends_priorities[i][1], + ) + selected_index = sorted_indices[0] + selected_backend = valid_backends_priorities[selected_index][0] + logger.info( + "Using %s backend.", + selected_backend.name, + ) + + return selected_backend.get_path() @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 15e3b3a22bde..4969bcf116a4 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -17,8 +17,9 @@ if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig + from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple): major: int minor: int + def __lt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) < (other.major, other.minor) + + def __le__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) <= (other.major, other.minor) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) == (other.major, other.minor) + + def __ge__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) >= (other.major, other.minor) + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, DeviceCapability): + return NotImplemented + return (self.major, self.minor) > (other.major, other.minor) + def as_version_str(self) -> str: return f"{self.major}.{self.minor}" @@ -173,19 +199,21 @@ def import_kernels(cls) -> None: import vllm._moe_C # noqa: F401 @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - # Import _Backend here to avoid circular import. - from vllm.attention.backends.registry import _Backend + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> "AttentionBackendEnum": + # Import AttentionBackendEnum here to avoid circular import. 
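The selection step above keeps the order produced by _get_backend_priorities: each valid backend carries its original index, and the lowest index wins. A standalone illustration of that step with toy data (equivalent to sorting by priority and taking the first entry):

# Toy data; in the patch these are (AttentionBackendEnum member, priority) pairs.
valid_backends_priorities = [("TRITON_ATTN", 2), ("FLASH_ATTN", 1)]

selected_backend, _ = min(valid_backends_priorities, key=lambda item: item[1])
print(selected_backend)  # FLASH_ATTN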
+ from vllm.attention.backends.registry import AttentionBackendEnum - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, - kv_cache_dtype: str | None, + kv_cache_dtype: "CacheDType | None", block_size: int, use_v1: bool, use_mla: bool, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e6536a02a73d..5318bdb8b36c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -14,10 +14,10 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig else: - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -204,21 +204,23 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> AttentionBackendEnum: from importlib.util import find_spec from vllm._aiter_ops import rocm_aiter_ops - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if rocm_aiter_ops.is_mha_enabled(): # Note: AITER FA is only supported for Qwen-VL models. # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA + return AttentionBackendEnum.ROCM_AITER_FA if on_gfx9() and find_spec("flash_attn") is not None: - return _Backend.FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN - return _Backend.TORCH_SDPA + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_attn_backend_cls( @@ -234,7 +236,7 @@ def get_attn_backend_cls( use_sparse, ) -> str: from vllm._aiter_ops import rocm_aiter_ops - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on ROCm.") @@ -248,55 +250,52 @@ def get_attn_backend_cls( if use_mla: if selected_backend is None: selected_backend = ( - _Backend.ROCM_AITER_MLA + AttentionBackendEnum.ROCM_AITER_MLA if rocm_aiter_ops.is_mla_enabled() or block_size == 1 - else _Backend.TRITON_MLA + else AttentionBackendEnum.TRITON_MLA ) - if selected_backend == _Backend.TRITON_MLA: + if selected_backend == AttentionBackendEnum.TRITON_MLA: if block_size != 1: logger.info_once("Using Triton MLA backend.") - return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" + return AttentionBackendEnum.TRITON_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." ) - if selected_backend == _Backend.ROCM_AITER_MLA: + if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA: logger.info("Using AITER MLA backend.") - return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + return AttentionBackendEnum.ROCM_AITER_MLA.get_path() raise ValueError( f" The selected backend, {selected_backend.name}," f"is not MLA type while requested for MLA backend." 
) - if selected_backend == _Backend.FLEX_ATTENTION: + if selected_backend == AttentionBackendEnum.FLEX_ATTENTION: logger.info("Using FlexAttention backend.") return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( rocm_aiter_ops.is_mha_enabled() - ) or selected_backend == _Backend.ROCM_AITER_FA: + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA: logger.info("Using Aiter Flash Attention backend.") - return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + return AttentionBackendEnum.ROCM_AITER_FA.get_path() if ( rocm_aiter_ops.is_triton_unified_attn_enabled() - ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN: logger.info("Using Aiter Unified Attention backend.") - return ( - "vllm.v1.attention.backends." - "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" - ) + return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path() if ( envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - or selected_backend == _Backend.ROCM_ATTN + or selected_backend == AttentionBackendEnum.ROCM_ATTN ): # rocm specific backend, with aiter and/or # triton prefix-prefill logger.info("Using Rocm Attention backend.") - return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" + return AttentionBackendEnum.ROCM_ATTN.get_path() # default case, using triton unified attention logger.info("Using Triton Attention backend.") - return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + return AttentionBackendEnum.TRITON_ATTN.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 1a4b67a1762f..575a9892c211 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -15,16 +15,15 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams else: BlockSize = None - ModelConfig = None VllmConfig = None PoolingParams = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -54,7 +53,7 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -64,17 +63,17 @@ def get_attn_backend_cls( has_sink, use_sparse, ) -> str: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") - if selected_backend != _Backend.PALLAS: + if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) if not use_v1: raise ValueError("TPU backend only supports V1.") logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + return AttentionBackendEnum.PALLAS.get_path() @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index e4ecd0c807da..684d6d9a6b57 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -14,12 +14,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum if 
TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.config import VllmConfig else: - ModelConfig = None VllmConfig = None - _Backend = None + AttentionBackendEnum = None logger = init_logger(__name__) @@ -44,7 +43,7 @@ def import_kernels(cls) -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "_Backend", + selected_backend: "AttentionBackendEnum", head_size: int, dtype: torch.dtype, kv_cache_dtype: str | None, @@ -62,18 +61,19 @@ def get_attn_backend_cls( "only NHD layout is supported by XPU attention kernels." ) - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 - if selected_backend == _Backend.TRITON_ATTN: + use_v1 = envs.VLLM_USE_V1 + if not use_v1: + raise ValueError("XPU backend only supports V1.") + if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend.") - return TRITON_ATTN - elif selected_backend == _Backend.FLASH_ATTN: + return AttentionBackendEnum.TRITON_ATTN.get_path() + elif selected_backend == AttentionBackendEnum.FLASH_ATTN: logger.info_once("Using Flash Attention backend.") - return FLASH_ATTN + return AttentionBackendEnum.FLASH_ATTN.get_path() elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " @@ -81,7 +81,7 @@ def get_attn_backend_cls( ) logger.info("Using Flash Attention backend.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + return AttentionBackendEnum.FLASH_ATTN.get_path() @classmethod def set_device(cls, device: torch.device) -> None: @@ -113,10 +113,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: return device_props.total_memory @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: - from vllm.attention.backends.registry import _Backend - - return _Backend.FLASH_ATTN + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype + ) -> AttentionBackendEnum: + return AttentionBackendEnum.FLASH_ATTN @classmethod def inference_mode(cls): diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 20d987fa2de3..0057a7e22882 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import numpy as np import torch @@ -40,23 +40,16 @@ class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: + def get_supported_head_sizes(cls) -> list[int]: attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) - if not is_valid: - attn_type = 
cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + return attn_impl.get_supported_head_sizes() @staticmethod def get_name() -> str: @@ -759,9 +752,8 @@ def _make_sliding_window_bias( class _PagedAttention: @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] - return head_size in SUPPORT_HS, SUPPORT_HS + def get_supported_head_sizes() -> list[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( @@ -861,8 +853,8 @@ def forward_decode( class _IPEXPagedAttention(_PagedAttention): @staticmethod - def validate_head_size(head_size: int) -> tuple[bool, list[int]]: - return True, [] + def get_supported_head_sizes() -> list[int]: + return [] @staticmethod def split_kv_cache( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 15bb2f4a40ac..9cec623814c9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,6 +3,7 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import numpy as np import torch @@ -32,11 +33,13 @@ reshape_and_cache_flash, ) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -52,34 +55,12 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # NOTE(tdoublep): while in principle, FA supports - # MultipleOf(16), these are the block sizes that do not - # suffer from the NaN propagation problem described here: - # https://github.com/Dao-AILab/flash-attention/issues/1974 - return [16, 32, 64] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] @staticmethod def get_name() -> str: @@ -125,6 +106,38 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + if kv_cache_dtype.startswith("fp8"): + return flash_attn_supports_fp8() + return kv_cache_dtype in ["auto"] + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(8, 0) + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if has_sink and device_capability < DeviceCapability(9, 0): + return "sink not supported on compute capability < 9.0" + return None + @dataclass class FlashAttentionMetadata: @@ -481,8 +494,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - FlashAttentionBackend.validate_head_size(head_size) - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() # Cache the batch invariant result for use in forward passes diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 683725b95819..07a0ab41a9e0 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -23,6 +23,7 @@ MultipleOf, ) from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -33,6 +34,7 @@ kNvfp4Quant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.flashinfer import ( can_use_trtllm_attention, @@ -45,6 +47,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + KVCacheLayoutType, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, @@ -158,34 +161,17 @@ def trtllm_prefill_attn_kvfp8_dequant( class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - return [64, 128, 256] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - # Note: Not sure for all platforms, - # but on Blackwell, only support a page size of - # 16, 32, 64 - return [16, 32, 64] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in 
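FlashAttentionBackend now declares its support surface as class-level data (supported dtypes, kernel block sizes, head sizes, kv-cache dtypes, a minimum compute capability) plus a supports_combination hook that returns a reason string, or None when the combination is fine. The base-class validate_configuration that consumes these declarations is not shown in this hunk; the sketch below is one plausible shape for such an aggregator, using made-up toy classes rather than the real vLLM AttentionBackend:

# Toy sketch of a declarative capability check; all names here are illustrative.
from typing import ClassVar

import torch


class ToyCapability(tuple):
    """Minimal (major, minor) stand-in for DeviceCapability."""

    def __new__(cls, major: int, minor: int):
        return super().__new__(cls, (major, minor))


class ToyBackend:
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_head_sizes: ClassVar[list[int]] = [64, 128, 256]
    min_capability: ClassVar[ToyCapability] = ToyCapability(8, 0)

    @classmethod
    def supports_combination(cls, has_sink: bool,
                             capability: ToyCapability) -> str | None:
        # None means "supported"; a string explains why the combo is rejected.
        if has_sink and capability < ToyCapability(9, 0):
            return "sink not supported on compute capability < 9.0"
        return None

    @classmethod
    def validate_configuration(cls, head_size: int, dtype: torch.dtype,
                               has_sink: bool,
                               capability: ToyCapability) -> list[str]:
        reasons = []
        if dtype not in cls.supported_dtypes:
            reasons.append(f"unsupported dtype {dtype}")
        if head_size not in cls.supported_head_sizes:
            reasons.append(f"unsupported head size {head_size}")
        if capability < cls.min_capability:
            reasons.append(f"needs compute capability >= {cls.min_capability}")
        combo_reason = cls.supports_combination(has_sink, capability)
        if combo_reason is not None:
            reasons.append(combo_reason)
        return reasons


print(ToyBackend.validate_configuration(96, torch.float32, True, ToyCapability(8, 6)))
# ['unsupported dtype torch.float32', 'unsupported head size 96',
#  'sink not supported on compute capability < 9.0']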
supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # Note: Not sure for all platforms, + # but on Blackwell, only support a page size of + # 16, 32, 64 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -231,6 +217,26 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + return [64, 128, 256] + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability( + 12, 1 + ) + + @classmethod + def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: + from vllm.platforms import current_platform + + capability = current_platform.get_device_capability() + if capability is not None and capability.major == 10: + return "HND" + return None + @dataclass class FlashInferMetadata: @@ -328,7 +334,6 @@ def __init__( ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size - FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size self.cache_dtype = self.cache_config.cache_dtype diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 9af63831cecb..e53cd0d8af4f 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -4,6 +4,7 @@ import math from dataclasses import dataclass +from typing import ClassVar import torch import torch._dynamo.decorators @@ -24,6 +25,7 @@ is_quantized_kv_cache, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -71,14 +73,12 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - return # FlexAttention supports any head size + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] @staticmethod def get_name() -> str: @@ -106,6 +106,10 @@ def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]: def use_cascade_attention(*args, **kwargs) -> bool: return False + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + # @torch.compile(fullgraph=True, mode="reduce-overhead") def physical_to_logical_mapping( @@ -720,7 +724,6 @@ def __init__( if kv_sharing_target_layer_name is not None: raise 
NotImplementedError("FlexAttention does not support kv sharing yet.") - FlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet" diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e38f7bcfa44e..b4cb5c200da3 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -308,25 +308,13 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) + def is_mla(cls) -> bool: + return True @dataclass @@ -425,8 +413,10 @@ class MLACommonMetadata(Generic[D]): ) = None def __post_init__(self): - if self.head_dim is not None: - MLACommonBackend.validate_head_size(self.head_dim) + if self.head_dim is not None and not MLACommonBackend.supports_head_size( + self.head_dim + ): + raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.") M = TypeVar("M", bound=MLACommonMetadata) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index c35e238eac4c..0a10ce74cd1d 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -13,7 +13,9 @@ MultipleOf, is_quantized_kv_cache, ) +from vllm.config.cache import CacheDType from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -33,6 +35,14 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class CutlassMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -45,9 +55,9 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [128] + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 class SM100Workspace: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 79b89c7890a2..5662acbe32c2 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache, ) from vllm.attention.utils.fa_utils import ( @@ -17,10 +18,12 @@ 
get_flash_attn_version, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -37,6 +40,10 @@ class FlashAttnMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "FLASH_ATTN_MLA" @@ -49,6 +56,26 @@ def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]: return FlashAttnMLAImpl + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 9 + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if not flash_attn_supports_mla(): + return "FlashAttention MLA not supported on this device" + return None + @dataclass class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index ebbcfd0eaa2f..b0f514ba4451 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -6,8 +6,14 @@ import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla -from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + MultipleOf, +) +from vllm.config.cache import CacheDType from vllm.logger import init_logger +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonImpl, @@ -15,7 +21,7 @@ MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType logger = init_logger(__name__) @@ -28,6 +34,14 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class FlashInferMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHINFER_MLA" @@ -41,8 +55,12 @@ def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: return FlashInferMLAMetadataBuilder @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return [32, 64] + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major == 10 + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return "HND" g_fi_workspace = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 708bb9d63839..8f0364cd58de 100644 --- 
a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,10 +13,12 @@ is_flashmla_dense_supported, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -36,6 +38,14 @@ class FlashMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + ] + @staticmethod def get_name() -> str: return "FLASHMLA" @@ -48,9 +58,30 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [64] + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if use_sparse: + from vllm.attention.ops.flashmla import is_flashmla_sparse_supported + + return is_flashmla_sparse_supported()[1] + else: + from vllm.attention.ops.flashmla import is_flashmla_dense_supported + + return is_flashmla_dense_supported()[1] @dataclass diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index bf76549de1ce..4794312eb96e 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import ( AttentionBackend, AttentionLayer, + MultipleOf, ) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.flashmla import ( @@ -18,8 +19,10 @@ get_mla_metadata, ) from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl @@ -51,6 +54,9 @@ class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"] @staticmethod def get_name() -> str: @@ -64,6 +70,22 @@ def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: def get_impl_cls() -> type["FlashMLASparseImpl"]: return FlashMLASparseImpl + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability.major in [9, 10] + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -79,14 +101,6 
@@ def get_kv_cache_shape( else: return (num_blocks, block_size, head_size) - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [576] - @dataclass class FlashMLASparseMetadata: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index f3c5bb732871..4f071145625f 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -23,6 +23,8 @@ class DeepseekV32IndexerBackend(AttentionBackend): + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 128] @@ -46,10 +48,6 @@ def get_kv_cache_shape( def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) - @classmethod - def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: - return [64] - @dataclass class DeepseekV32IndexerPrefillChunkMetadata: diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 781f77e96319..0149639e8c0b 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import ClassVar import torch @@ -12,11 +13,13 @@ ) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, @@ -28,6 +31,9 @@ class TritonMLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] + @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -36,6 +42,10 @@ def get_name() -> str: def get_impl_cls() -> type["TritonMLAImpl"]: return TritonMLAImpl + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index e8d3758a6395..81991244f5d9 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -3,6 +3,7 @@ """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass +from typing import ClassVar import torch @@ -445,31 +446,13 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - 
return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "FLASH_ATTN" @@ -531,8 +514,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - AiterFlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 57ba4dc78d9f..1d2c70f65d0f 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -152,10 +152,7 @@ def build( class RocmAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -163,12 +160,11 @@ def get_supported_head_sizes(cls) -> list[int]: @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: + if not cls.supports_head_size(head_size): attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " + f"Supported head sizes are: {cls.get_supported_head_sizes()}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." ) diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 0c0222d6152f..1bf38ed225a4 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -4,7 +4,7 @@ import ast from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -30,31 +30,13 @@ class TreeAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) - @staticmethod def get_name() -> str: return "TREE_ATTN" @@ -331,8 +313,6 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) - TreeAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 0590a87bf8e5..37c0ae61e65d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -18,12 +18,14 @@ ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, ) from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -147,25 +149,18 @@ def build( class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16, torch.float32] - - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - # Triton Attention supports any head size above 32 - if head_size < 32: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention." - f"Head sizes need to be larger or equal 32 for this backend. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." 
- ) + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "fp8", + "fp8_e4m3", + "fp8_e5m2", + ] @staticmethod def get_name() -> str: @@ -195,6 +190,18 @@ def use_cascade_attention(*args, **kwargs) -> bool: def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + return head_size >= 32 + + @classmethod + def supports_sink(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return True + class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): @@ -237,8 +244,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - TritonAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 81bdbd641429..d15d79417cc6 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch @@ -41,10 +41,8 @@ class XFormersAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - - @classmethod - def get_supported_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -80,22 +78,6 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] - @staticmethod - def get_supported_kernel_block_size() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: - attn_type = cls.__name__.removesuffix("Backend") - raise ValueError( - f"Head size {head_size} is not supported by {attn_type}. " - f"Supported head sizes are: {supported_head_sizes}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes." - ) - @staticmethod def get_name() -> str: return "XFORMERS" @@ -305,8 +287,6 @@ def __init__( logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap - XFormersAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: raise NotImplementedError( "Encoder self-attention and " diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75a4140fd655..55b04949ceb2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -150,11 +150,15 @@ def __init__( ) # Determine allowed attention backends once during initialization. 
+ from vllm.attention.backends.registry import AttentionBackendEnum + self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] - # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend - if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + # ROCM_AITER_FA is an optional backend + if find_spec( + AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) + ): from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6fccf2ea2f47..790649b69e5c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4371,7 +4371,7 @@ def block_size_is_supported( """ for backend in backends: is_supported = False - for supported_size in backend.get_supported_kernel_block_size(): + for supported_size in backend.supported_kernel_block_sizes: if isinstance(supported_size, int): if block_size == supported_size: is_supported = True @@ -4402,7 +4402,7 @@ def block_size_is_supported( all_int_supported_sizes = set( supported_size for backend in backends - for supported_size in backend.get_supported_kernel_block_size() + for supported_size in backend.supported_kernel_block_sizes if isinstance(supported_size, int) ) From 7dbe6d81d6f17abe93389d97d417e4886467546f Mon Sep 17 00:00:00 2001 From: Chaojun Zhang Date: Tue, 11 Nov 2025 20:46:47 +0800 Subject: [PATCH 053/183] Fix Fused MoE LoRA Triton kernel bug (#28450) Signed-off-by: chaojun-zhang --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 6d6de2529de3..893972144e99 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -26,7 +26,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): tensor_ptrs = [] for lora_weight in lora_weights: tensor_ptrs.append(lora_weight.data_ptr()) - ptr_tensor = torch.tensor(tensor_ptrs, device=device) + ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) _LORA_PTR_DICT[key] = ptr_tensor return _LORA_PTR_DICT.get(key) @@ -85,6 +85,7 @@ def _fused_moe_lora_kernel( GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, USE_GDC: tl.constexpr, + launch_pdl: tl.constexpr, IS_PRIMARY: tl.constexpr, ): pid = tl.program_id(axis=0) From afffd3cc8a99ce1cf0f6f1687852e5519d725a3b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 11 Nov 2025 21:14:48 +0800 Subject: [PATCH 054/183] [Model] Pass `mm_features` directly into `get_mrope_input_positions` (#28399) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/ernie45_vl.py | 35 +++++------- vllm/model_executor/models/glm4_1v.py | 32 +++++------ vllm/model_executor/models/glm4v.py | 32 +++++------ vllm/model_executor/models/interfaces.py | 22 ++------ vllm/model_executor/models/keye.py | 29 ++++------ vllm/model_executor/models/keye_vl1_5.py | 29 ++++------ vllm/model_executor/models/paddleocr_vl.py | 29 ++++------ .../models/qwen2_5_omni_thinker.py | 46 +++++++++------- vllm/model_executor/models/qwen2_5_vl.py | 36 ++++++------ vllm/model_executor/models/qwen2_vl.py | 37 +++++-------- .../models/qwen3_omni_moe_thinker.py | 55 +++++++++++-------- vllm/model_executor/models/qwen3_vl.py | 30 ++++------ .../models/transformers/multimodal.py | 39 +++++++++---- 
vllm/multimodal/inputs.py | 13 +++++ vllm/v1/worker/gpu_model_runner.py | 33 ++--------- 15 files changed, 225 insertions(+), 272 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 97182a25f82b..c040b19bba20 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,7 +34,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import ( @@ -58,6 +58,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, ) @@ -1433,15 +1434,16 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for Ernie VL.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + hf_config = self.config image_token_id = hf_config.im_patch_id video_start_token_id = hf_config.video_start_token_id video_end_token_id = hf_config.video_end_token_id @@ -1449,10 +1451,7 @@ def get_mrope_input_positions( temporal_conv_size = hf_config.temporal_conv_size llm_pos_ids_list: list = [] - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - + if image_grid_thw or video_grid_thw: input_token_type: list[str] = [] video_check_flg = False for token in input_tokens: @@ -1484,11 +1483,7 @@ def get_mrope_input_positions( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) + t, h, w = image_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_conv_size, @@ -1519,11 +1514,7 @@ def get_mrope_input_positions( mm_data_idx += 1 elif modality_type == "video": - t, h, w = ( - video_grid_thw[mm_data_idx][0], - video_grid_thw[mm_data_idx][1], - video_grid_thw[mm_data_idx][2], - ) + t, h, w = video_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t // temporal_conv_size, h // spatial_conv_size, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 776527fdd973..60cad2e2907f 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -37,7 +37,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from 
transformers.models.glm4v.image_processing_glm4v import ( Glm4vImageProcessor, @@ -70,6 +70,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1619,25 +1620,23 @@ def get_multimodal_embeddings( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: "PretrainedConfig", - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + hf_config = self.config image_token_id = hf_config.image_token_id video_start_token_id = hf_config.video_start_token_id video_end_token_id = hf_config.video_end_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size llm_pos_ids_list: list = [] - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - + if image_grid_thw or video_grid_thw: input_token_type: list[str] = [] video_check_flg = False for token in input_tokens: @@ -1669,11 +1668,7 @@ def get_mrope_input_positions( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) + t, h, w = image_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, @@ -1706,8 +1701,7 @@ def get_mrope_input_positions( elif modality_type == "video": t, h, w = ( video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], + *image_grid_thw[mm_data_idx][1:], ) llm_grid_t, llm_grid_h, llm_grid_w = ( t, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index ebf6934dddea..899797a51053 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -15,7 +15,7 @@ from torch.nn import LayerNorm from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType +from transformers import BatchFeature, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -36,6 +36,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, ) @@ -622,25 +623,23 @@ def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tenso def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: 
list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + hf_config = self.config image_token_id = hf_config.image_token_id video_start_token_id = hf_config.video_start_token_id video_end_token_id = hf_config.video_end_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size llm_pos_ids_list: list = [] - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - + if image_grid_thw or video_grid_thw: input_token_type: list[str] = [] video_check_flg = False for token in input_tokens: @@ -672,11 +671,7 @@ def get_mrope_input_positions( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) + t, h, w = image_grid_thw[mm_data_idx] llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, @@ -709,8 +704,7 @@ def get_mrope_input_positions( elif modality_type == "video": t, h, w = ( video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], + *image_grid_thw[mm_data_idx][1:], ) llm_grid_t, llm_grid_h, llm_grid_w = ( t, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d6a8f86d998b..88b45bf07c0d 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn from torch import Tensor -from transformers import PretrainedConfig from transformers.models.whisper.tokenization_whisper import LANGUAGES from typing_extensions import Self, TypeIs @@ -32,10 +31,12 @@ if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.models.utils import WeightsMapper + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.sequence import IntermediateTensors else: VllmConfig = object WeightsMapper = object + MultiModalFeatureSpec = object IntermediateTensors = object logger = init_logger(__name__) @@ -991,12 +992,7 @@ class SupportsMRoPE(Protocol): def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list["MultiModalFeatureSpec"], ) -> tuple[torch.Tensor, int]: """ Get M-RoPE input positions and delta value for this specific model. 
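# Illustrative sketch (not part of this patch): how an implementation of the
# SupportsMRoPE protocol above is expected to consume `mm_features`, mirroring
# the pattern the surrounding model hunks (ernie45_vl, glm4v, qwen2_vl, ...)
# adopt. The helper name `unpack_grids` and the choice of grid keys are
# assumptions made for the example only.
from vllm.multimodal.inputs import MultiModalFeatureSpec


def unpack_grids(
    mm_features: list[MultiModalFeatureSpec],
) -> tuple[list[list[int]], list[list[int]]]:
    # gather_kwargs (added to vllm/multimodal/inputs.py in this series) returns
    # one list per requested key, with an entry for every feature carrying it.
    kwargs = MultiModalFeatureSpec.gather_kwargs(
        mm_features, {"image_grid_thw", "video_grid_thw"}
    )
    image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
    video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
    return image_grid_thw, video_grid_thw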
@@ -1006,17 +1002,11 @@ def get_mrope_input_positions( Args: input_tokens: List of input token IDs - hf_config: HuggingFace model configuration - image_grid_thw: Image grid dimensions (t, h, w) - video_grid_thw: Video grid dimensions (t, h, w) - second_per_grid_ts: Seconds per grid timestep for videos - audio_feature_lengths: Audio feature lengths for multimodal models - use_audio_in_video: Whether to use audio in video for interleaving + mm_features: Information about each multi-modal data item Returns: - Tuple of (llm_positions, mrope_position_delta) - - llm_positions: Tensor of shape [3, num_tokens] - with T/H/W positions + Tuple of `(llm_positions, mrope_position_delta)` + - llm_positions: Tensor of shape `[3, num_tokens]` with T/H/W positions - mrope_position_delta: Delta for position calculations """ ... diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 80d7e6c5b0cd..aa0134badc40 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -40,6 +40,7 @@ ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1627,16 +1628,17 @@ def _process_video_input( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: video_grid_thw = video_grid_thw[0] - """Get mrope input positions and delta value (Keye series).""" def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: """ @@ -1662,6 +1664,7 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: video_grid_thw = split_thw(video_grid_thw) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size @@ -1691,20 +1694,12 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_frames -= 1 ed = ed_video diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 6f95a59d36d2..124e9c2afa21 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -21,6 +21,7 @@ from vllm.multimodal.inputs import ( ImageItem, ModalityData, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -597,16 +598,17 @@ def _process_video_input( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: 
PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: video_grid_thw = video_grid_thw[0] - """Get mrope input positions and delta value (Keye series).""" def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: """ @@ -632,6 +634,7 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: video_grid_thw = split_thw(video_grid_thw) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size @@ -661,20 +664,12 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_frames -= 1 ed = ed_video diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 86d7d1c11ffe..62994abe8e31 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -61,6 +61,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargs, ) @@ -1184,15 +1185,17 @@ def compute_logits( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float], - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1229,20 +1232,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - 
video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index fac281d2caf4..8f74cab0534d 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -68,6 +68,7 @@ ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, @@ -923,21 +924,9 @@ def get_language_model(self) -> torch.nn.Module: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - + """ Example: (V_i are vision position ids, A_i are audio position ids) @@ -945,11 +934,33 @@ def get_mrope_input_positions( |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... """ + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) # TODO(fyabc): refactor and share more code with # _vl_get_input_positions_tensor. 
- thinker_config = hf_config.thinker_config + thinker_config = self.config audio_token_id = thinker_config.audio_token_index image_token_id = thinker_config.image_token_index video_token_id = thinker_config.video_token_index @@ -963,11 +974,6 @@ def get_mrope_input_positions( thinker_config.vision_config, "tokens_per_second", 25 ) - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - src_item = input_tokens audio_seqlens = audio_feature_lengths if not second_per_grid_ts: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3292cf8220ff..4662176a1cc5 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -35,7 +35,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, @@ -75,7 +75,11 @@ compute_retention_mask, recompute_mrope_positions, ) -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldConfig, + MultiModalKwargs, +) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors @@ -1120,15 +1124,17 @@ class Qwen2_5_VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float], - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1165,20 +1171,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 61057fa145f4..bbebe7c0f928 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -34,7 +34,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat 
-from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( Qwen2VLConfig, @@ -70,6 +70,7 @@ ImageItem, ModalityData, MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, @@ -1240,21 +1241,17 @@ class Qwen2VLForConditionalGeneration( def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get M-RoPE input positions for Qwen2-VL model.""" - if image_grid_thw is None: - image_grid_thw = [] - if video_grid_thw is None: - video_grid_thw = [] - if second_per_grid_ts is None: - second_per_grid_ts = [] + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw", "second_per_grid_ts"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1291,20 +1288,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_second_per_grid_t = 1.0 if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[video_index] diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 468b25220154..e6cb4442e2be 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -65,7 +65,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems from vllm.multimodal.processing import ( BaseMultiModalProcessor, @@ -1414,39 +1414,48 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - config = 
hf_config.thinker_config - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + second_per_grid_ts = kwargs.get("second_per_grid_ts", []) + audio_feature_lengths = kwargs.get("audio_feature_lengths", []) + use_audio_in_video = any(kwargs.get("use_audio_in_video", [])) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) + input_ids = torch.tensor(input_tokens) if input_ids is None or input_ids.ndim != 1: raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") seq_len = input_ids.shape[0] - if audio_feature_lengths is not None and not isinstance( - audio_feature_lengths, torch.Tensor - ): - audio_feature_lengths = torch.as_tensor( + + if isinstance(audio_feature_lengths, list): + audio_feature_lengths = torch.tensor( audio_feature_lengths, dtype=torch.long ) - if second_per_grid_ts is None: - if video_grid_thw is not None and video_grid_thw.numel() > 0: - second_per_grids = torch.ones( - video_grid_thw.shape[0], dtype=torch.float32 - ) - else: - second_per_grids = torch.tensor([], dtype=torch.float32) + + if not len(second_per_grid_ts) and len(video_grid_thw): + second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32) else: second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) + config = self.config spatial_merge_size = config.vision_config.spatial_merge_size image_token_id = config.image_token_id video_token_id = config.video_token_id diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1be35cde7dbd..97d4667d82e9 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -34,7 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( smart_resize as image_smart_resize, @@ -70,6 +70,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, @@ -1416,17 +1417,18 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + {"image_grid_thw", "video_grid_thw"}, + ) + image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] + video_grid_thw = [item.tolist() for item in 
kwargs.get("video_grid_thw", [])] video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + hf_config = self.config image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id @@ -1455,20 +1457,12 @@ def get_mrope_input_positions( else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + t, h, w = image_grid_thw[image_index] image_index += 1 remain_images -= 1 ed = ed_image else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) + t, h, w = video_grid_thw[video_index] video_index += 1 remain_videos -= 1 ed = ed_video diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 476074542e6a..2efcef68d1c7 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -27,6 +27,7 @@ from vllm.multimodal import MultiModalKwargsItems from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalInputs, MultiModalUUIDDict, @@ -38,7 +39,7 @@ from vllm.sequence import IntermediateTensors if TYPE_CHECKING: - from transformers import BatchFeature, PretrainedConfig + from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -367,20 +368,34 @@ def get_multimodal_embeddings(self, **kwargs): def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: "PretrainedConfig", - image_grid_thw: list[list[int]] | torch.Tensor | None, - video_grid_thw: list[list[int]] | torch.Tensor | None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)): + kwargs = MultiModalFeatureSpec.gather_kwargs( + mm_features, + { + "image_grid_thw", + "video_grid_thw", + "second_per_grid_ts", + "audio_feature_lengths", + "use_audio_in_video", + }, + ) + if any( + v + for k, v in kwargs.items() + if k not in {"image_grid_thw", "video_grid_thw"} + ): raise NotImplementedError("Transformers backend only supports images.") - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) + image_grid_thw = kwargs.get("image_grid_thw", []) + video_grid_thw = kwargs.get("video_grid_thw", []) + + image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)( + image_grid_thw + ) + video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)( + video_grid_thw + ) mrope_positions, mrope_position_delta = self.model.get_rope_index( input_ids=torch.tensor(input_tokens).unsqueeze(0), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index a05f54191f04..7518a023c5f5 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -249,6 +249,19 @@ class MultiModalFeatureSpec: mm_position: PlaceholderRange """e.g., PlaceholderRange(offset=2, length=336)""" + @staticmethod + def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): + kwargs = defaultdict[str, list[NestedTensors]](list) + + for f in 
features: + item = f.data + if item is not None: + for k in keys: + if k in item: + kwargs[k].append(item[k].data) + + return dict(kwargs) + @dataclass class MultiModalFieldElem: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 790649b69e5c..fbd3e5f31316 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -892,38 +892,13 @@ def _update_states_after_model_execute( self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _init_mrope_positions(self, req_state: CachedRequestState): - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_feature in req_state.mm_features: - mm_item = mm_feature.data - if mm_item is None: - continue - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + model = self.get_model() + assert supports_mrope(model), "M-RoPE support is not implemented." req_state.mrope_positions, req_state.mrope_position_delta = ( - self.model.get_mrope_input_positions( + model.get_mrope_input_positions( req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, + req_state.mm_features, ) ) From 3380543b2075abd6f3e6e283f4eacb307354e33a Mon Sep 17 00:00:00 2001 From: Ido Segev Date: Tue, 11 Nov 2025 15:41:18 +0200 Subject: [PATCH 055/183] Add request timeout override for multi-turn benchmarks (#28386) Signed-off-by: Ido Segev --- .../benchmark_serving_multi_turn.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 5d2ac66e5ab9..2c1a051cc9c9 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -63,6 +63,7 @@ class RequestArgs(NamedTuple): stream: bool limit_min_tokens: int # Use negative value for no limit limit_max_tokens: int # Use negative value for no limit + timeout_sec: int class BenchmarkArgs(NamedTuple): @@ -214,6 +215,7 @@ async def send_request( stream: bool = True, min_tokens: int | None = None, max_tokens: int | None = None, + timeout_sec: int = 120, ) -> ServerResponse: payload = { "model": model, @@ -235,10 +237,16 @@ async def send_request( headers = {"Content-Type": "application/json"} # Calculate the timeout for the request - timeout_sec = 120 if max_tokens is not None: # Assume TPOT of 200ms and use max_tokens to determine timeout - timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) + token_based_timeout = int(max_tokens * 0.2) + if token_based_timeout > timeout_sec: + timeout_sec = token_based_timeout + logger.info( + "Using timeout of %ds based on max_tokens %d", + timeout_sec, + max_tokens, + ) timeout = aiohttp.ClientTimeout(total=timeout_sec) valid_response = True @@ -409,6 
+417,7 @@ async def send_turn( req_args.stream, min_tokens, max_tokens, + req_args.timeout_sec, ) if response.valid is False: @@ -676,8 +685,18 @@ async def client_main( except asyncio.exceptions.TimeoutError: num_failures += 1 - logger.exception( - f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + logger.error( + "%sClient %d - Timeout during conversation ID %s (turn: %d). " + "Base timeout is %ss (set with --request-timeout-sec), but the " + "effective timeout may be longer based on max_tokens. If this " + "is unexpected, consider increasing the timeout or checking " + "model performance.%s", + Color.RED, + client_id, + conv_id, + current_turn, + req_args.timeout_sec, + Color.RESET, ) break # Exit gracefully instead of raising an error @@ -815,6 +834,9 @@ def get_client_config( "Invalid min/max tokens limits (min should not be larger than max)" ) + if args.request_timeout_sec <= 0: + raise ValueError("Request timeout must be a positive number") + # Arguments for API requests chat_url = f"{args.url}/v1/chat/completions" model_name = args.served_model_name if args.served_model_name else args.model @@ -825,6 +847,7 @@ def get_client_config( stream=not args.no_stream, limit_min_tokens=args.limit_min_tokens, limit_max_tokens=args.limit_max_tokens, + timeout_sec=args.request_timeout_sec, ) return client_args, req_args @@ -968,7 +991,7 @@ async def main_mp( f"(is alive: {client.is_alive()}){Color.RESET}" ) - client.join(timeout=120) + client.join(timeout=req_args.timeout_sec + 1) if client.is_alive(): logger.warning( @@ -1351,6 +1374,13 @@ async def main() -> None: action="store_true", help="Verify the LLM output (compare to the answers in the input JSON file)", ) + parser.add_argument( + "--request-timeout-sec", + type=int, + default=120, + help="Timeout in seconds for each API request (default: 120). " + "Automatically increased if max tokens imply longer decoding.", + ) parser.add_argument( "--no-stream", From fa1970201d2efae6db48ca808ba50b63390457db Mon Sep 17 00:00:00 2001 From: Maryam Tahhan Date: Tue, 11 Nov 2025 14:01:11 +0000 Subject: [PATCH 056/183] [Docs] Fix grammar in CPU installation guide (#28461) Signed-off-by: Maryam Tahhan --- docs/getting_started/installation/cpu.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 2369eaed1802..dbfefa9a1fe5 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -93,7 +93,7 @@ Currently, there are no pre-built CPU wheels. ## Related runtime environment variables -- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. +- `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. 
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. - `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. @@ -128,7 +128,7 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe ### How to decide `VLLM_CPU_OMP_THREADS_BIND`? -- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to a same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If have any performance problems or unexpected binding behaviours, please try to bind threads as following. +- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to the same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If you have any performance problems or unexpected binding behaviours, please try to bind threads as following. - On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores: @@ -156,12 +156,12 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe 14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000 15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000 - # On this platform, it is recommend to only bind openMP threads on logical CPU cores 0-7 or 8-15 + # On this platform, it is recommended to only bind openMP threads on logical CPU cores 0-7 or 8-15 $ export VLLM_CPU_OMP_THREADS_BIND=0-7 $ python examples/offline_inference/basic/basic.py ``` -- When deploy vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. So be aware to set CPU cores of a single rank on a same NUMA node to avoid cross NUMA node memory access. +- When deploying vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. 
So be aware to set CPU cores of a single rank on the same NUMA node to avoid cross NUMA node memory access. ### How to decide `VLLM_CPU_KVCACHE_SPACE`? @@ -171,7 +171,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. A larger batch usually provides higher throughput, a smaller batch provides lower latency. Tuning the max batch size starting from the default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -192,8 +192,8 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? - Both of them require `amx` CPU flag. - - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. + - `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models + - `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios. ### Why do I see `get_mempolicy: Operation not permitted` when running in Docker? 
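As a minimal sketch of how the tuning knobs documented in the guide above fit together (the concrete values here — KV cache size and core range — are illustrative assumptions for a single-socket run, not part of the patch):

```console
# Illustrative only: reserve 40 GiB for the KV cache and pin the OpenMP
# threads of the single rank to cores 0-31; adjust both values to your
# own hardware and NUMA topology before use.
$ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND=0-31
$ python examples/offline_inference/basic/basic.py
```

If thread binding is left at the default `auto`, only `VLLM_CPU_KVCACHE_SPACE` needs to be set explicitly; the core range above simply mirrors the manual-binding example from the guide.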
From a1448b4b69b15c33b4fbc9a883c4f3b9559ee7db Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Tue, 11 Nov 2025 09:29:02 -0500 Subject: [PATCH 057/183] [Kernels] Split up fused_moe/layer.py, isolate more modular kernel code (#28064) --- .../moe/modular_kernel_tools/mk_objects.py | 9 +- vllm/lora/layers/fused_moe.py | 4 +- .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/all2all_utils.py | 160 +++ .../layers/fused_moe/fused_moe_method_base.py | 112 +++ .../fused_moe/fused_moe_modular_method.py | 164 +++ vllm/model_executor/layers/fused_moe/layer.py | 950 +----------------- .../layers/fused_moe/shared_fused_moe.py | 2 +- .../fused_moe/unquantized_fused_moe_method.py | 578 +++++++++++ .../layers/quantization/mxfp4.py | 29 +- 10 files changed, 1064 insertions(+), 948 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/all2all_utils.py create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_method_base.py create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py create mode 100644 vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 21eeffb1c726..d79fdfbe07af 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -6,6 +6,10 @@ # Fused experts and PrepareFinalize imports import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) @@ -21,7 +25,6 @@ BatchedTritonExperts, NaiveBatchedExperts, ) -from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) @@ -399,9 +402,7 @@ def make_prepare_finalize( quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize( - moe, quant_config - ) + prepare_finalize = maybe_make_prepare_finalize(moe, quant_config) assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dadb9e25ba2f..8fb3efa220f6 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -25,7 +25,9 @@ modular_triton_fused_moe, try_get_optimal_moe_config, ) -from vllm.model_executor.layers.fused_moe.layer import FusedMoEModularMethod +from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( + FusedMoEModularMethod, +) class FusedMoEWithLoRA(BaseLayerWithLoRA): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index cb31045971bd..53d98d0650b4 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -5,9 +5,11 @@ from typing import Any from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) from vllm.model_executor.layers.fused_moe.layer import 
( FusedMoE, - FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py new file mode 100644 index 000000000000..2dd625054339 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +from vllm.distributed import ( + get_ep_group, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEPrepareAndFinalize, +) +from vllm.platforms import current_platform +from vllm.utils.import_utils import has_deep_ep, has_pplx + +if current_platform.is_cuda_alike(): + if has_pplx(): + from .pplx_prepare_finalize import ( + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) + if has_deep_ep(): + from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize + from .deepep_ll_prepare_finalize import ( + DEEPEP_QUANT_BLOCK_SHAPE, + DeepEPLLPrepareAndFinalize, + ) + + +def maybe_roundup_layer_hidden_size( + hidden_size: int, + act_dtype: torch.dtype, + moe_parallel_config: FusedMoEParallelConfig, +) -> int: + """ + Given layer hidden size and MoE configurations, round up hidden_size + if necessary. + + Args: + hidden_size: Layer hidden-size + act_dtype: Data type of the layer activations. + moe_parallel_config: Fused MoE parallelization strategy configuration. + + Return: + Rounded up hidden_size if rounding up is required based on the configs + and all2all backend. + Original hidden size otherwise. 
+ """ + if moe_parallel_config.use_deepep_ht_kernels: + hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size, act_dtype + ) + + if moe_parallel_config.use_deepep_ll_kernels: + hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size + ) + + return hidden_size + + +def maybe_make_prepare_finalize( + moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig | None, +) -> FusedMoEPrepareAndFinalize | None: + if not moe.moe_parallel_config.use_all2all_kernels: + return None + + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + prepare_finalize: FusedMoEPrepareAndFinalize | None = None + + # TODO: could allow this now + assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" + + if moe.use_pplx_kernels: + assert quant_config is not None + + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( + moe.max_num_tokens, + moe.hidden_dim, + moe.in_dtype, + quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=hidden_scale_bytes, + ) + + num_dispatchers = ( + all2all_manager.world_size // all2all_manager.tp_group.world_size + ) + + # Intranode pplx a2a takes a group name while internode does not. + if not all2all_manager.internode: + all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name + + handle = all2all_manager.get_handle(all_to_all_args) + + prepare_finalize = PplxPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + num_local_experts=moe.num_local_experts, + num_dispatchers=num_dispatchers, + ) + elif moe.use_deepep_ht_kernels: + assert moe.dp_size == all2all_manager.dp_world_size + + all_to_all_args = dict() + handle = all2all_manager.get_handle(all_to_all_args) + prepare_finalize = DeepEPHTPrepareAndFinalize( + handle, + num_dispatchers=all2all_manager.world_size, + dp_size=all2all_manager.dp_world_size, + rank_expert_offset=all2all_manager.rank * moe.num_local_experts, + ) + + elif moe.use_deepep_ll_kernels: + assert quant_config is not None + all_to_all_args = dict( + max_num_tokens_per_dp_rank=moe.max_num_tokens, + token_hidden_size=moe.hidden_dim, + num_ep_ranks=all2all_manager.world_size, + num_global_experts=moe.num_experts, + num_local_experts=moe.num_experts // all2all_manager.world_size, + ) + handle = all2all_manager.get_handle(all_to_all_args) + + # Note: We may want to use FP8 dispatch just to reduce + # data movement. 
+ use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE + ) + + prepare_finalize = DeepEPLLPrepareAndFinalize( + handle, + max_tokens_per_rank=moe.max_num_tokens, + num_dispatchers=all2all_manager.world_size, + use_fp8_dispatch=use_fp8_dispatch, + ) + + return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py new file mode 100644 index 000000000000..87f8c8d75a9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from collections.abc import Callable + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase, +) + +logger = init_logger(__name__) + + +class FusedMoEMethodBase(QuantizeMethodBase): + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.moe: FusedMoEConfig = moe + self.moe_quant_config: FusedMoEQuantConfig | None = None + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def uses_weight_scale_2_pattern(self) -> bool: + """ + Returns True if this quantization method uses 'weight_scale_2' pattern + for per-tensor weight scales (e.g., FP4 variants), False otherwise. + + This method should be overridden by subclasses that use the + 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. 
+ """ + return False + + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: + from .all2all_utils import maybe_make_prepare_finalize + + return maybe_make_prepare_finalize(self.moe, self.moe_quant_config) + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> FusedMoEPermuteExpertsUnpermute: + # based on the all2all implementation, select the appropriate + # gemm implementation + raise NotImplementedError( + f"{self.__class__.__name__} must select appropriate gemm " + "implementation based on the prepare_finalize" + ) + + @abstractmethod + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + raise NotImplementedError + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + @property + def supports_eplb(self) -> bool: + return False + + @property + def allow_inplace(self) -> bool: + return False + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py new file mode 100644 index 000000000000..43974ba917e4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel, + FusedMoEPrepareAndFinalize, +) + +logger = init_logger(__name__) + + +@CustomOp.register("modular_fused_moe") +class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): + def __init__( + self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel + ): + super().__init__(old_quant_method.moe) + self.moe_quant_config = old_quant_method.moe_quant_config + self.fused_experts = experts + self.disable_expert_map = getattr( + old_quant_method, + "disable_expert_map", + not self.fused_experts.supports_expert_map(), + ) + self.old_quant_method = old_quant_method + logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) + + @staticmethod + def make( + moe_layer: torch.nn.Module, + old_quant_method: FusedMoEMethodBase, + prepare_finalize: FusedMoEPrepareAndFinalize, + shared_experts: torch.nn.Module | None, + ) -> "FusedMoEModularMethod": + return 
FusedMoEModularMethod( + old_quant_method, + FusedMoEModularKernel( + prepare_finalize, + old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), + shared_experts, + ), + ) + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return self.fused_experts.prepare_finalize.topk_indices_dtype() + + @property + def supports_eplb(self) -> bool: + return self.old_quant_method.supports_eplb + + @property + def allow_inplace(self) -> bool: + return self.old_quant_method.allow_inplace + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return self.moe_quant_config + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # Is getattr needed? + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + if enable_eplb: + if self.supports_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + else: + raise NotImplementedError( + "EPLB is not supported for " + f"{self.old_quant_method.__class__.__name__}." 
+ ) + + topk_weights, topk_ids, zero_expert_result = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + ) + + result = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=self.allow_inplace, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=None if self.disable_expert_map else expert_map, + ) + + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 39547cc83c7b..e198322ba7a8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import abstractmethod from collections.abc import Callable, Iterable from contextlib import nullcontext from enum import Enum @@ -27,17 +26,13 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, RoutingMethodType, - biased_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, - FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) @@ -47,35 +42,17 @@ from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, - QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( is_flashinfer_supporting_global_sf, ) -from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.platforms.interface import CpuArchEnum -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.import_utils import has_deep_ep, has_pplx from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import current_stream, direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): - from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts, eplb_map_to_physical_and_record, fused_experts - - if has_pplx(): - from .pplx_prepare_finalize import ( - PplxPrepareAndFinalize, - 
pplx_hidden_dim_scale_bytes, - ) - if has_deep_ep(): - from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import ( - DEEPEP_QUANT_BLOCK_SHAPE, - DeepEPLLPrepareAndFinalize, - ) + from .fused_moe import eplb_map_to_physical_and_record, fused_experts else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = object # type: ignore @@ -102,6 +79,16 @@ def _eplb_map_to_physical_and_record( else: fused_moe_pallas = None # type: ignore +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( + FusedMoEModularMethod, +) +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) + logger = init_logger(__name__) @@ -112,885 +99,6 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -class FusedMoEMethodBase(QuantizeMethodBase): - def __init__(self, moe: FusedMoEConfig): - super().__init__() - self.moe: FusedMoEConfig = moe - self.moe_quant_config: FusedMoEQuantConfig | None = None - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - def uses_weight_scale_2_pattern(self) -> bool: - """ - Returns True if this quantization method uses 'weight_scale_2' pattern - for per-tensor weight scales (e.g., FP4 variants), False otherwise. - - This method should be overridden by subclasses that use the - 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. - """ - return False - - @staticmethod - def _maybe_make_prepare_finalize( - moe: FusedMoEConfig, - quant_config: FusedMoEQuantConfig | None, - ) -> FusedMoEPrepareAndFinalize | None: - all2all_manager = get_ep_group().device_communicator.all2all_manager - assert all2all_manager is not None - - prepare_finalize: FusedMoEPrepareAndFinalize | None = None - - # TODO: could allow this now - assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" - - if moe.use_pplx_kernels: - assert quant_config is not None - - hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( - moe.max_num_tokens, - moe.hidden_dim, - moe.in_dtype, - quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - - all_to_all_args = dict( - max_num_tokens=moe.max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=all2all_manager.rank, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=hidden_dim_bytes, - hidden_dim_scale_bytes=hidden_scale_bytes, - ) - - num_dispatchers = ( - all2all_manager.world_size // all2all_manager.tp_group.world_size - ) - - # Intranode pplx a2a takes a group name while internode does not. 
- if not all2all_manager.internode: - all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name - - handle = all2all_manager.get_handle(all_to_all_args) - - prepare_finalize = PplxPrepareAndFinalize( - handle, - max_num_tokens=moe.max_num_tokens, - num_local_experts=moe.num_local_experts, - num_dispatchers=num_dispatchers, - ) - elif moe.use_deepep_ht_kernels: - assert moe.dp_size == all2all_manager.dp_world_size - - all_to_all_args = dict() - handle = all2all_manager.get_handle(all_to_all_args) - prepare_finalize = DeepEPHTPrepareAndFinalize( - handle, - num_dispatchers=all2all_manager.world_size, - dp_size=all2all_manager.dp_world_size, - rank_expert_offset=all2all_manager.rank * moe.num_local_experts, - ) - - elif moe.use_deepep_ll_kernels: - assert quant_config is not None - all_to_all_args = dict( - max_num_tokens_per_dp_rank=moe.max_num_tokens, - token_hidden_size=moe.hidden_dim, - num_ep_ranks=all2all_manager.world_size, - num_global_experts=moe.num_experts, - num_local_experts=moe.num_experts // all2all_manager.world_size, - ) - handle = all2all_manager.get_handle(all_to_all_args) - - # Note: We may want to use FP8 dispatch just to reduce - # data movement. - use_fp8_dispatch = ( - quant_config.quant_dtype == current_platform.fp8_dtype() - and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE - ) - - prepare_finalize = DeepEPLLPrepareAndFinalize( - handle, - max_tokens_per_rank=moe.max_num_tokens, - num_dispatchers=all2all_manager.world_size, - use_fp8_dispatch=use_fp8_dispatch, - ) - - return prepare_finalize - - def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: - if self.moe.moe_parallel_config.use_all2all_kernels: - return FusedMoEMethodBase._maybe_make_prepare_finalize( - self.moe, self.moe_quant_config - ) - else: - return None - - def maybe_init_modular_kernel( - self, layer: torch.nn.Module - ) -> FusedMoEModularKernel | None: - assert self.moe is not None - - # We must get the quant config here so that the layer is - # completely initialized, i.e. all weights loaded and post - # processed. 
- self.moe_quant_config = self.get_fused_moe_quant_config(layer) - - prepare_finalize = self.maybe_make_prepare_finalize() - - if prepare_finalize is not None: - logger.debug( - "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) - ) - experts = self.select_gemm_impl(prepare_finalize, layer) - return FusedMoEModularKernel( - prepare_finalize, - experts, - layer.shared_experts, - ) - else: - return None - - def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: - # based on the all2all implementation, select the appropriate - # gemm implementation - raise NotImplementedError( - f"{self.__class__.__name__} must select appropriate gemm " - "implementation based on the prepare_finalize" - ) - - @abstractmethod - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - raise NotImplementedError - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - @property - def supports_eplb(self) -> bool: - return False - - @property - def allow_inplace(self) -> bool: - return False - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - -@CustomOp.register("modular_fused_moe") -class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): - def __init__( - self, - old_quant_method: FusedMoEMethodBase, - fused_experts: FusedMoEModularKernel, - ): - super().__init__(old_quant_method.moe) - # Find better way to copy attributes? Should we even copy attributes? 
- # self.__dict__.update(old_quant_method.__dict__) - self.moe_quant_config = old_quant_method.moe_quant_config - self.fused_experts = fused_experts - self.disable_expert_map = getattr( - old_quant_method, - "disable_expert_map", - not fused_experts.supports_expert_map(), - ) - self.old_quant_method = old_quant_method - logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - return self.fused_experts.prepare_finalize.topk_indices_dtype() - - @property - def supports_eplb(self) -> bool: - return self.old_quant_method.supports_eplb - - @property - def allow_inplace(self) -> bool: - return self.old_quant_method.allow_inplace - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - return self.moe_quant_config - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # Is getattr needed? - zero_expert_num = getattr(layer, "zero_expert_num", 0) - zero_expert_type = getattr(layer, "zero_expert_type", None) - - if enable_eplb: - if self.supports_eplb: - assert expert_load_view is not None - assert logical_to_physical_map is not None - assert logical_replica_count is not None - assert isinstance(layer, FusedMoE) - else: - raise NotImplementedError( - "EPLB is not supported for " - f"{self.old_quant_method.__class__.__name__}." 
- ) - - topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - global_num_experts=global_num_experts, - zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type, - ) - - result = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=self.allow_inplace, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=None if self.disable_expert_map else expert_map, - ) - - if zero_expert_num != 0 and zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result - - -@CustomOp.register("unquantized_fused_moe") -class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): - """MoE method without quantization.""" - - def __init__(self, moe: FusedMoEConfig): - super().__init__(moe) - - self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() - if self.rocm_aiter_moe_enabled: - from .rocm_aiter_fused_moe import rocm_aiter_fused_experts - - self.rocm_aiter_fused_experts = rocm_aiter_fused_experts - else: - self.rocm_aiter_fused_experts = None # type: ignore - - # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS - self.flashinfer_cutlass_moe_enabled = ( - has_flashinfer_cutlass_fused_moe() - and envs.VLLM_USE_FLASHINFER_MOE_FP16 - and self.moe.moe_parallel_config.use_ep - and self.moe.moe_parallel_config.dp_size == 1 - and current_platform.get_device_capability()[0] >= 9 - ) - if self.flashinfer_cutlass_moe_enabled: - logger.info_once( - "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" - ) - from functools import partial - - from .flashinfer_cutlass_moe import flashinfer_cutlass_moe - - self.flashinfer_cutlass_moe = partial( - flashinfer_cutlass_moe, - quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, - tp_rank=self.moe.moe_parallel_config.tp_rank, - tp_size=self.moe.moe_parallel_config.tp_size, - ep_rank=self.moe.moe_parallel_config.ep_rank, - ep_size=self.moe.moe_parallel_config.ep_size, - ) - else: - if ( - self.moe.moe_parallel_config.use_ep - and self.moe.moe_parallel_config.dp_size == 1 - ): - logger.info_once( - "FlashInfer CUTLASS MoE is available for EP" - " but not enabled, consider setting" - " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", - scope="local", - ) - elif self.moe.moe_parallel_config.dp_size > 1: - logger.info_once( - "FlashInfer CUTLASS MoE is currently not available for DP.", - scope="local", - ) - self.flashinfer_cutlass_moe = None # type: ignore - - @property - def supports_eplb(self) -> bool: - return True - - @property - def allow_inplace(self) -> bool: - return True - - def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: - if self.rocm_aiter_moe_enabled: - return None - else: - return 
super().maybe_make_prepare_finalize() - - def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - if ( - prepare_finalize.activation_format - == FusedMoEActivationFormat.BatchedExperts - ): - logger.debug("BatchedTritonExperts %s", self.moe) - return BatchedTritonExperts( - max_num_tokens=self.moe.max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - ) - else: - logger.debug("TritonExperts %s", self.moe) - return TritonExperts(self.moe_quant_config) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - if self.moe.is_act_and_mul: - w13_up_dim = 2 * intermediate_size_per_partition - else: - w13_up_dim = intermediate_size_per_partition - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w13_up_dim, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - if self.moe.has_bias: - w13_bias = torch.nn.Parameter( - torch.zeros(num_experts, w13_up_dim, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, extra_weight_attrs) - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - if self.moe.has_bias: - w2_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, extra_weight_attrs) - - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: - # Pad the weight tensor. 
This is an optimization on ROCm platform, which - # can benefit from tensors located far enough from one another in memory - if ( - envs.VLLM_ROCM_MOE_PADDING - and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0 - ): - num_pad = 256 // weight.element_size() - weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() - - return weight - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) - - # Padding the weight for better performance on ROCm - layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) - layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - - if self.rocm_aiter_moe_enabled: - shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data - ) - - layer.w13_weight.data = shuffled_w13 - layer.w2_weight.data = shuffled_w2 - - if self.flashinfer_cutlass_moe_enabled: - # Swap halves to arrange as [w3; w1] (kernel expectation) - w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) - w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) - layer.w13_weight.data = w13_weight_swapped.contiguous() - - if current_platform.is_xpu(): - import intel_extension_for_pytorch as ipex - - ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=True, - experts_start_id=ep_rank_start, - ) - elif current_platform.is_cpu(): - from vllm.model_executor.layers.fused_moe import cpu_fused_moe - - if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.utils import check_cpu_sgl_kernel - - dtype_w13 = layer.w13_weight.dtype - _, n_w13, k_w13 = layer.w13_weight.size() - dtype_w2 = layer.w2_weight.dtype - _, n_w2, k_w2 = layer.w2_weight.size() - if ( - envs.VLLM_CPU_SGL_KERNEL - and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) - and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2) - ): - packed_w13_weight = torch.ops._C.convert_weight_packed( - layer.w13_weight - ) - assert packed_w13_weight.size() == layer.w13_weight.size() - layer.w13_weight.copy_(packed_w13_weight) - del packed_w13_weight - packed_w2_weight = torch.ops._C.convert_weight_packed( - layer.w2_weight - ) - assert packed_w2_weight.size() == layer.w2_weight.size() - layer.w2_weight.copy_(packed_w2_weight) - layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) - else: - layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) - else: - layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if enable_eplb: - assert expert_load_view is not None - assert 
logical_to_physical_map is not None - assert logical_replica_count is not None - assert isinstance(layer, FusedMoE) - - return self.forward( - x=x, - layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - global_num_experts=global_num_experts, - expert_map=expert_map, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - enable_eplb=enable_eplb, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - if self.moe.has_bias: - return biased_moe_quant_config( - layer.w13_bias, - layer.w2_bias, - ) - else: - return FUSED_MOE_UNQUANTIZED_CONFIG - - def forward_cuda( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - zero_expert_num = getattr(layer, "zero_expert_num", 0) - zero_expert_type = getattr(layer, "zero_expert_type", None) - - topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - global_num_experts=global_num_experts, - zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type, - num_fused_shared_experts=layer.num_fused_shared_experts, - ) - - if self.rocm_aiter_moe_enabled: - result = self.rocm_aiter_fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - elif self.flashinfer_cutlass_moe_enabled: - return self.flashinfer_cutlass_moe( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - result = fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - 
topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - quant_config=self.moe_quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) - - if zero_expert_num != 0 and zero_expert_type is not None: - assert not isinstance(result, tuple), ( - "Shared + zero experts are mutually exclusive not yet supported" - ) - return result, zero_expert_result - else: - return result - - def forward_cpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for CPU.") - return layer.cpu_fused_moe( - layer, - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - global_num_experts, - expert_map, - custom_routing_function, - scoring_func, - routed_scaling_factor, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - ) - - def forward_xpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for XPU.") - return layer.ipex_fusion( - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - custom_routing_function=custom_routing_function, - ) - - def forward_tpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = 
None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not use_grouped_topk - assert num_expert_group is None - assert topk_group is None - assert custom_routing_function is None - assert apply_router_weight_on_input is False - if scoring_func != "softmax": - raise NotImplementedError( - "Only softmax scoring function is supported for TPU." - ) - if e_score_correction_bias is not None: - raise NotImplementedError( - "Expert score correction bias is not supported for TPU." - ) - assert activation == "silu", f"{activation} is not supported for TPU." - assert routed_scaling_factor == 1.0, ( - f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." - ) - if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for TPU.") - return fused_moe_pallas( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk=top_k, - gating_output=router_logits, - global_num_experts=global_num_experts, - expert_map=expert_map, - renormalize=renormalize, - ) - - if current_platform.is_tpu(): - forward_native = forward_tpu - elif current_platform.is_cpu(): - forward_native = forward_cpu - elif current_platform.is_xpu(): - forward_native = forward_xpu - else: - forward_native = forward_cuda - - def determine_expert_map( ep_size: int, ep_rank: int, @@ -1125,16 +233,13 @@ def maybe_roundup_hidden_size( Rounded up hidden_size if rounding up is required based on the configs. Original hidden size otherwise. """ + from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_roundup_layer_hidden_size, + ) - if moe_parallel_config.use_deepep_ht_kernels: - hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( - hidden_size, act_dtype - ) - - if moe_parallel_config.use_deepep_ll_kernels: - hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size( - hidden_size - ) + hidden_size = maybe_roundup_layer_hidden_size( + hidden_size, act_dtype, moe_parallel_config + ) # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": @@ -1430,7 +535,6 @@ def __init__( is_lora_enabled=vllm_config.lora_config is not None, ) - self.moe_quant_config: FusedMoEQuantConfig | None = None self.quant_config = quant_config def _get_quant_method() -> FusedMoEMethodBase: @@ -1508,9 +612,15 @@ def _get_quant_method() -> FusedMoEMethodBase: # This is called after all weight loading and post-processing, so it # should be safe to swap out the quant_method. 
def maybe_init_modular_kernel(self) -> None: - mk = self.quant_method.maybe_init_modular_kernel(self) - if mk is not None: - self.quant_method = FusedMoEModularMethod(self.quant_method, mk) + self.ensure_moe_quant_config_init() + prepare_finalize = self.quant_method.maybe_make_prepare_finalize() + if prepare_finalize is not None: + logger.debug( + "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) + ) + self.quant_method = FusedMoEModularMethod.make( + self, self.quant_method, prepare_finalize, self.shared_experts + ) @property def shared_experts(self) -> torch.nn.Module | None: @@ -2142,12 +1252,16 @@ def set_eplb_state( def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: + # Note: the moe_quant_config can't be constructed until after + # weight loading post processing. self.quant_method.moe_quant_config = ( self.quant_method.get_fused_moe_quant_config(self) ) - if self.moe_quant_config is None: - self.moe_quant_config = self.quant_method.moe_quant_config + @property + def moe_quant_config(self) -> FusedMoEQuantConfig | None: + self.ensure_moe_quant_config_init() + return self.quant_method.moe_quant_config def ensure_dp_chunking_init(self): if not self.use_dp_chunking or self.batched_hidden_states is not None: diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 3d0c5636d6c0..06112ca51b6d 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -38,7 +38,7 @@ def __init__( and not ( # TODO(wentao): find the root cause and remove this condition self.enable_eplb - or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py new file mode 100644 index 000000000000..ce56887f1c26 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +import torch.nn.functional as F + +import vllm.envs as envs +from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEQuantConfig, + biased_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe + +if current_platform.is_cuda_alike(): + from .fused_batched_moe import BatchedTritonExperts + from .fused_moe import TritonExperts, fused_experts +else: + fused_experts = None # type: ignore + +if current_platform.is_tpu(): + from .moe_pallas import fused_moe as fused_moe_pallas +else: + fused_moe_pallas = None # type: ignore + +logger = 
init_logger(__name__) + + +@CustomOp.register("unquantized_fused_moe") +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if self.rocm_aiter_moe_enabled: + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts + else: + self.rocm_aiter_fused_experts = None # type: ignore + + # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS + self.flashinfer_cutlass_moe_enabled = ( + has_flashinfer_cutlass_fused_moe() + and envs.VLLM_USE_FLASHINFER_MOE_FP16 + and self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + and current_platform.get_device_capability()[0] >= 9 + ) + if self.flashinfer_cutlass_moe_enabled: + logger.info_once( + "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" + ) + from functools import partial + + from .flashinfer_cutlass_moe import flashinfer_cutlass_moe + + self.flashinfer_cutlass_moe = partial( + flashinfer_cutlass_moe, + quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, + tp_rank=self.moe.moe_parallel_config.tp_rank, + tp_size=self.moe.moe_parallel_config.tp_size, + ep_rank=self.moe.moe_parallel_config.ep_rank, + ep_size=self.moe.moe_parallel_config.ep_size, + ) + else: + if ( + self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + ): + logger.info_once( + "FlashInfer CUTLASS MoE is available for EP" + " but not enabled, consider setting" + " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", + scope="local", + ) + elif self.moe.moe_parallel_config.dp_size > 1: + logger.info_once( + "FlashInfer CUTLASS MoE is currently not available for DP.", + scope="local", + ) + self.flashinfer_cutlass_moe = None # type: ignore + + @property + def supports_eplb(self) -> bool: + return True + + @property + def allow_inplace(self) -> bool: + return True + + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: + if self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + logger.debug("BatchedTritonExperts %s", self.moe) + return BatchedTritonExperts( + max_num_tokens=self.moe.max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: + logger.debug("TritonExperts %s", self.moe) + return TritonExperts(self.moe_quant_config) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.moe.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_up_dim, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.moe.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros(num_experts, w13_up_dim, 
dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + if self.moe.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if ( + envs.VLLM_ROCM_MOE_PADDING + and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0 + ): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + + return weight + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) + + if self.rocm_aiter_moe_enabled: + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data + ) + + layer.w13_weight.data = shuffled_w13 + layer.w2_weight.data = shuffled_w2 + + if self.flashinfer_cutlass_moe_enabled: + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + layer.w13_weight.data = w13_weight_swapped.contiguous() + + if current_platform.is_xpu(): + import intel_extension_for_pytorch as ipex + + ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=True, + experts_start_id=ep_rank_start, + ) + elif current_platform.is_cpu(): + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel + + dtype_w13 = layer.w13_weight.dtype + _, n_w13, k_w13 = layer.w13_weight.size() + dtype_w2 = layer.w2_weight.dtype + _, n_w2, k_w2 = layer.w2_weight.size() + if ( + envs.VLLM_CPU_SGL_KERNEL + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2) + ): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight + ) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight + ) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + 
top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + return self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + topk_weights, topk_ids, zero_expert_result = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + 
num_fused_shared_experts=layer.num_fused_shared_experts, + ) + + if self.rocm_aiter_moe_enabled: + result = self.rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif self.flashinfer_cutlass_moe_enabled: + return self.flashinfer_cutlass_moe( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + result = fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result + + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for CPU.") + return layer.cpu_fused_moe( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + routed_scaling_factor, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + ) + + def forward_xpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not 
None + ): + raise NotImplementedError("Expert load balancing is not supported for XPU.") + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function=custom_routing_function, + ) + + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for TPU." + ) + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for TPU." + ) + assert activation == "silu", f"{activation} is not supported for TPU." + assert routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." + ) + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for TPU.") + return fused_moe_pallas( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, + renormalize=renormalize, + ) + + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + elif current_platform.is_xpu(): + forward_native = forward_xpu + else: + forward_native = forward_cuda diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index e339f15510d7..4e51249f2d25 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -741,15 +741,10 @@ def _interleave_mxfp4_cutlass_sm90(w): weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) ) - self.w13_weight_triton_tensor = w13_weight - self.w2_weight_triton_tensor = w2_weight - - # need to delete the original weights to save memory on single GPU - del layer.w13_weight - del layer.w2_weight - layer.w13_weight = None - layer.w2_weight = None - torch.cuda.empty_cache() + self.w13_weight = w13_weight + self.w2_weight = w2_weight + layer.w13_weight = w13_weight + layer.w2_weight = w2_weight else: raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") @@ -824,18 +819,6 @@ def select_gemm_impl( "EP batched experts format" ) else: - layer.w13_weight = ( - self.w13_weight_triton_tensor - if layer.w13_weight is None - else layer.w13_weight - ) - layer.w2_weight = ( - self.w2_weight_triton_tensor - if layer.w2_weight is None - else 
layer.w2_weight - ) - assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]]) - assert self.moe_quant_config is not None if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM @@ -1070,8 +1053,8 @@ def apply( return triton_kernel_moe_forward( hidden_states=x, - w1=self.w13_weight_triton_tensor, - w2=self.w2_weight_triton_tensor, + w1=self.w13_weight, + w2=self.w2_weight, gating_output=router_logits, topk=top_k, renormalize=renormalize, From 533b018f725fb9c2421e2c4b5a48d62fa5f1d844 Mon Sep 17 00:00:00 2001 From: jvlunteren <161835099+jvlunteren@users.noreply.github.com> Date: Tue, 11 Nov 2025 15:41:43 +0100 Subject: [PATCH 058/183] [BugFix] Fix Failing Ruff Check (#28469) Signed-off-by: Jan van Lunteren --- tests/compile/test_fusions_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index f67063cdf42e..e1560efb3f24 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -74,7 +74,7 @@ class ModelBackendTestCase(NamedTuple): ModelBackendTestCase( model_name="Qwen/Qwen3-30B-A3B", model_kwargs=dict(max_model_len=1024), - backend=_Backend.TRITON_ATTN, + backend=AttentionBackendEnum.TRITON_ATTN, attention_fusions=0, allreduce_fusions=97, ), From a90ad7d838b446cfc2dd7b4252086e13c3a8abbf Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 11 Nov 2025 15:03:22 +0000 Subject: [PATCH 059/183] Add @markmc to CODEOWNERS for Observability (#28457) Signed-off-by: Mark McLoughlin --- .github/CODEOWNERS | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 23def076cf88..f26c782bccf2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -61,6 +61,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/model_executor/models/transformers @hmellor /tests/models/test_transformers.py @hmellor +# Observability +/vllm/config/observability.py @markmc +/vllm/v1/metrics @markmc +/tests/v1/metrics @markmc +/vllm/tracing.py @markmc +/tests/v1/tracing/test_tracing.py @markmc +/vllm/config/kv_events.py @markmc +/vllm/distributed/kv_events.py @markmc +/tests/distributed/test_events.py @markmc + # Docs /docs/mkdocs @hmellor /docs/**/*.yml @hmellor From b886068056a05857f796909d2f8573b36fc668a5 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Nov 2025 23:29:33 +0800 Subject: [PATCH 060/183] [BugFix] Fix RuntimeError in PixtralHFAttention on CPU/XPU (#28444) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 0555717017cd..dfe5f0c52a50 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1109,7 +1109,7 @@ def forward( ) out = out.transpose(1, 2) - out = out.view(batch, patches, self.n_heads * self.head_dim) + out = out.reshape(batch, patches, self.n_heads * self.head_dim) attn_output, _ = self.o_proj(out) return attn_output, None From 3143eb23fc4e017bc31d11a9756d5a788d6f7e33 Mon Sep 17 00:00:00 2001 From: usberkeley <150880684+usberkeley@users.noreply.github.com> Date: Wed, 12 Nov 2025 00:01:30 +0800 Subject: [PATCH 061/183] [BugFix] Add test_outputs.py to CI pipeline (#28466) Signed-off-by: Bradley Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .buildkite/test-amd.yaml | 1 + .buildkite/test-pipeline.yaml | 1 + 2 files changed, 2 insertions(+) diff --git 
a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index bb5ef5d62463..5fd048c2ad0c 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -348,6 +348,7 @@ steps: - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 83a7df3b093f..25f711dd60b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -329,6 +329,7 @@ steps: - pytest -v -s -m 'not cpu_test' v1/metrics - pytest -v -s v1/test_oracle.py - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py # Integration test for streaming correctness (requires special branch). - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine From 287bbbeb067cd9e16ea9b834b35b47258a8ad43f Mon Sep 17 00:00:00 2001 From: the-codeboy <71213855+the-codeboy@users.noreply.github.com> Date: Tue, 11 Nov 2025 17:45:49 +0100 Subject: [PATCH 062/183] [Doc] Fix typo in serving docs (#28474) Signed-off-by: the-codeboy <71213855+the-codeboy@users.noreply.github.com> --- docs/serving/openai_compatible_server.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index e331b3422ea6..821628e6e317 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -77,11 +77,11 @@ In addition, we have the following custom APIs: In order for the language model to support chat protocol, vLLM requires the model to include a chat template in its tokenizer configuration. The chat template is a Jinja2 template that -specifies how are roles, messages, and other chat-specific tokens are encoded in the input. +specifies how roles, messages, and other chat-specific tokens are encoded in the input. An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) -Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, +Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat template, or the template in string form. Without a chat template, the server will not be able to process chat and all chat requests will error. 
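As a concrete illustration of the chat-template behavior described in the documentation change above, the following sketch renders a prompt with a model's bundled Jinja2 template. It relies only on the Hugging Face tokenizer API, not on anything introduced in this patch, and the model name is simply the example already cited in the docs:

from transformers import AutoTokenizer

# Load the tokenizer of a chat-tuned model that ships a chat template.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is vLLM?"},
]

# The Jinja2 chat template decides how roles, messages, and chat-specific
# tokens are encoded into the final prompt string.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)

A server started for a model without such a template (and without --chat-template) cannot process chat requests, which is the failure mode the paragraph above warns about.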
From f9a4087182ffcd9404779fcda876f820b3b26d5f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 09:46:04 -0700 Subject: [PATCH 063/183] Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431) Signed-off-by: mgoin --- benchmarks/kernels/bench_block_fp8_gemm.py | 43 +++++++++++++------ .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 3 +- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 2 +- .../layers/quantization/utils/fp8_utils.py | 22 ++-------- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index f1e504499eaf..11e3ac7f0c1f 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +# Disable DeepGEMM for this benchmark to use CUTLASS +os.environ["VLLM_USE_DEEP_GEMM"] = "0" + import torch from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, @@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - # Create random FP8 tensors + # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp) A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + # Create quantized weight tensor B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - # Create scales + # Create weight scales block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k @@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): * factor_for_scale ) - # SM90 CUTLASS requires row-major format for scales - if use_cutlass and current_platform.is_device_capability(90): - Bs = Bs.T.contiguous() + # Create W8A8BlockFp8LinearOp instance + weight_group_shape = GroupShape(block_n, block_k) + act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization + + linear_op = W8A8BlockFp8LinearOp( + weight_group_shape=weight_group_shape, + act_quant_group_shape=act_quant_group_shape, + cutlass_block_fp8_supported=use_cutlass, + use_aiter_and_is_supported=False, + ) def run(): - if use_cutlass: - return apply_w8a8_block_fp8_linear( - A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True - ) - else: - return apply_w8a8_block_fp8_linear( - A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False - ) + return linear_op.apply( + input=A_ref, + weight=B, + weight_scale=Bs, + input_scale=None, + bias=None, + ) return run diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index 147eb8efc077..c40d49966271 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -48,7 +48,8 @@ struct 
cutlass_3x_gemm_fp8_blockwise { using ElementBlockScale = float; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig< - ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>; + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::GMMA::Major::MN, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 6da136cbc8f6..ee99572f5f49 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -173,7 +173,7 @@ def process_weights_after_loading(self, layer) -> None: layer.input_scale = None if self.strategy == QuantizationStrategy.BLOCK: - maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + maybe_post_process_fp8_weight_block(layer) def apply_weights( self, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 83d136600b77..cb065eb68b66 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -540,7 +540,7 @@ def process_weights_after_loading(self, layer: Module) -> None: return if self.block_quant: - maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + maybe_post_process_fp8_weight_block(layer) def apply( self, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index c63196b89357..0c54cf4def00 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -55,17 +55,13 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, - is_hopper: bool | None = None, ) -> torch.Tensor: - if is_hopper is None: - is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, - # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None and is_hopper else Bs.T, + scale_b=Bs.T, ) @@ -130,7 +126,7 @@ def _padded_cutlass( padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale) output = cutlass_scaled_mm( - padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True + padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype ) return output[0 : qx.shape[0], ...] 
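For context on the scale handling touched in this hunk: block-wise FP8 assigns one float32 scale per [block_n, block_k] tile of the weight, and after this patch the CUTLASS path always receives those scales as Bs.T instead of special-casing a row-major layout on SM90. A minimal reference dequantization, assuming N and K divide evenly by the block sizes (the function and argument names are illustrative only, not part of this patch):

import torch

def dequant_block_fp8_reference(
    B: torch.Tensor,   # [N, K] weight in torch.float8_e4m3fn
    Bs: torch.Tensor,  # [N // block_n, K // block_k] float32 per-block scales
    block_n: int = 128,
    block_k: int = 128,
) -> torch.Tensor:
    # Broadcast each per-block scale over its [block_n, block_k] tile.
    scales = Bs.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
    return B.to(torch.float32) * scales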
@@ -303,7 +299,6 @@ def _run_cutlass( weight_scale, list(self.weight_group_shape), input_2d.dtype, - False, ) def _run_aiter( @@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy( return weight, weight_scale -def maybe_post_process_fp8_weight_block( - layer: torch.nn.Module, cutlass_block_fp8_supported: bool -): +def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): assert layer.weight_block_size is not None from vllm.utils.deep_gemm import ( @@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block( requant_weight_ue8m0_inplace( layer.weight.data, layer.weight_scale.data, block_sz ) - # SM90 Block FP8 CUTLASS requires row-major weight scales - elif ( - current_platform.is_device_capability(90) - and cutlass_block_fp8_supported - and not should_use_deepgemm - ): - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data.T.contiguous(), requires_grad=False - ) def expert_weight_is_col_major(x: torch.Tensor) -> bool: From a7ef3eb0cd03e729c7a29914400e0ca928767999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Tue, 11 Nov 2025 17:57:43 +0100 Subject: [PATCH 064/183] [NIXL] Generalize block-first backend layouts (FlashInfer-like) (#28282) --- .../kv_connector/unit/test_nixl_connector.py | 17 ++++++- .../kv_connector/v1/nixl_connector.py | 47 +++++++++++++++---- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 475cf2285e39..8e421717fea3 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): llm.llm_engine.engine_core.shutdown() -def test_register_kv_caches(dist_init): +@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"]) +def test_register_kv_caches(dist_init, attn_backend, monkeypatch): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. 
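The parametrization added above exercises the connector change in this patch: instead of a hard-coded FlashInfer check, the layout is now inferred from the attention backend's KV-cache shape. A standalone sketch of that probe, mirroring the TpKVTopology.__post_init__ logic introduced below (illustrative only):

def is_kv_layout_blocks_first(attn_backend) -> bool:
    # Probe with num_blocks=1 so the two layouts can be told apart:
    #   blocks-first (FlashInfer-like): [num_blocks, 2, block_size, H, D] -> shape[0] == 1
    #   K/V-first:                      [2, num_blocks, block_size, H, D] -> shape[0] == 2
    shape = attn_backend.get_kv_cache_shape(
        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
    )
    return len(shape) == 5 and shape[0] == 1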
@@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init): block layout info """ + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + vllm_config = create_vllm_config() + # Import the appropriate backend based on the parameter + if attn_backend == "FLASH_ATTN": + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + backend_cls = FlashAttentionBackend + else: # TRITON_ATTN + from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend + + backend_cls = TritonAttentionBackend + # Create test kv cache tensors using proper backend shape - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + kv_cache_shape = backend_cls.get_kv_cache_shape( num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6c20eee1ecbf..375ea79d0e81 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,6 +21,7 @@ import zmq from vllm import envs +from vllm.attention import AttentionBackend from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig @@ -669,6 +670,33 @@ class TpKVTopology: remote_tp_size: dict[EngineId, int] is_mla: bool total_num_kv_heads: int + attn_backend: type[AttentionBackend] + + def __post_init__(self): + # Figure out whether the first dimension of the cache is K/V + # or num_blocks. This is used to register the memory regions correctly. + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], + # we just mock num_blocks to 1 for the dimension check below. + self._is_kv_layout_blocks_first = ( + len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 + ) + + attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS + + @property + def is_kv_layout_blocks_first(self) -> bool: + return self._is_kv_layout_blocks_first + + @property + def split_k_and_v(self) -> bool: + # Whether to register regions for K and V separately (when present). + return not ( + self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first + ) def tp_ratio( self, @@ -876,9 +904,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): use_mla=self.use_mla, ) self.backend_name = backend.get_name() - attn_backend = AttentionBackendEnum[self.backend_name] - self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) @@ -896,7 +921,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): remote_tp_size=self._tp_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=backend, ) + self._use_pallas = self.kv_topo._use_pallas def _nixl_handshake( self, @@ -1076,7 +1103,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). 
# Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) + split_k_and_v = self.kv_topo.split_k_and_v tensor_size_bytes = None # Enable different block lengths for different layers when MLA is used. self.block_len_per_layer = list[int]() @@ -1141,7 +1168,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.device_kv_caches = kv_caches self.dst_num_blocks[self.engine_id] = self.num_blocks - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: for i in range(len(self.slot_size_per_layer)): assert self.slot_size_per_layer[i] % 2 == 0 self.slot_size_per_layer[i] //= 2 @@ -1169,7 +1196,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (addr, len, device id) blocks_data.append((addr, kv_block_len, self.device_id)) - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # Separate and interleave K/V regions to maintain the same # descs ordering. This is needed for selecting contiguous heads # when split across TP ranks. @@ -1331,7 +1358,7 @@ def add_remote_agent( # (addr, len, device id) blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # With FlashInfer index V separately to allow head splitting. for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1414,7 +1441,7 @@ def _validate_remote_agent_handshake( remote_block_size = remote_block_len // ( self.slot_size_per_layer[0] * tp_ratio ) - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # With flashinfer, KV are sent in the same message. remote_block_size //= 2 @@ -1494,7 +1521,7 @@ def permute_device_kv(self, block_ids: list[int]): - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back """ - split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) + split_k_and_v = self.kv_topo.split_k_and_v inv_order = [0, 2, 1, 3] sample_cache = list(self.device_kv_caches.values())[0][0] target_shape = list(sample_cache.shape) @@ -1874,7 +1901,7 @@ def get_backend_aware_kv_block_len(self, layer_idx: int): For FlashInfer, this is half the length of the whole block, as K and V share the same region. """ - if self._use_flashinfer: + if self.kv_topo.is_kv_layout_blocks_first: # For indexing only half (either just the K or V part). 
block_len = self.block_len_per_layer[layer_idx] // 2 else: From 68c09efc37e87032640cf8db571eaf486bd744ac Mon Sep 17 00:00:00 2001 From: zhrrr <43847754+izhuhaoran@users.noreply.github.com> Date: Wed, 12 Nov 2025 01:00:31 +0800 Subject: [PATCH 065/183] [Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model (#27165) Signed-off-by: zhuhaoran --- .buildkite/test-pipeline.yaml | 1 + CMakeLists.txt | 1 + csrc/fused_qknorm_rope_kernel.cu | 418 ++++++++++++++++++ csrc/ops.h | 6 + csrc/torch_bindings.cpp | 10 + csrc/type_convert.cuh | 60 ++- tests/compile/test_qk_norm_rope_fusion.py | 195 ++++++++ tests/kernels/core/test_fused_qk_norm_rope.py | 141 ++++++ vllm/_custom_ops.py | 29 ++ vllm/compilation/fix_functionalization.py | 17 + vllm/compilation/fusion.py | 4 + vllm/compilation/matcher_utils.py | 81 +++- vllm/compilation/pass_manager.py | 4 + vllm/compilation/qk_norm_rope_fusion.py | 238 ++++++++++ vllm/config/compilation.py | 13 + .../layers/rotary_embedding/base.py | 45 +- 16 files changed, 1234 insertions(+), 29 deletions(-) create mode 100644 csrc/fused_qknorm_rope_kernel.cu create mode 100644 tests/compile/test_qk_norm_rope_fusion.py create mode 100644 tests/kernels/core/test_fused_qk_norm_rope.py create mode 100644 vllm/compilation/qk_norm_rope_fusion.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 25f711dd60b3..8d2a7bc5a802 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -451,6 +451,7 @@ steps: - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_aot_compile.py + - pytest -v -s compile/test_qk_norm_rope_fusion.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e9fa63b178e..5cddf81a4b4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -265,6 +265,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/fused_qknorm_rope_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/cuda_view.cu" diff --git a/csrc/fused_qknorm_rope_kernel.cu b/csrc/fused_qknorm_rope_kernel.cu new file mode 100644 index 000000000000..cbd23975a773 --- /dev/null +++ b/csrc/fused_qknorm_rope_kernel.cu @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "type_convert.cuh" + +#define CHECK_TYPE(x, st) \ + TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \ + ", while ", st, " is expected") +#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_TH_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define FINAL_MASK 0xffffffff + +// TODO: suport for AMD ROCM platform +#ifndef USE_ROCM +namespace tensorrt_llm::common { +template +struct packed_as; +// Specialization for packed_as used in this kernel. +template <> +struct packed_as { + using type = uint; +}; + +template <> +struct packed_as { + using type = uint2; +}; + +template <> +struct packed_as { + using type = uint4; +}; + +template +__inline__ __device__ T warpReduceSum(T val) { + #pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +template +inline __device__ __host__ T divUp(T m, T n) { + return (m + n - 1) / n; +} + +} // namespace tensorrt_llm::common + +namespace tensorrt_llm::kernels { +// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation, +// with added support for passing the cos_sin_cache as an input. +// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu + +// Perform per-head QK Norm and RoPE in a single kernel. +// scalar_t_in: data type of QKV and RMSNorm weights +// scalar_t_cache: data type of cos/sin cache +// head_dim: the dimension of each head +// interleave: interleave=!is_neox. +template +__global__ void fusedQKNormRopeKernel( + void* qkv_void, // Combined QKV tensor + int const num_heads_q, // Number of query heads + int const num_heads_k, // Number of key heads + int const num_heads_v, // Number of value heads + float const eps, // Epsilon for RMS normalization + void const* q_weight_void, // RMSNorm weights for query + void const* k_weight_void, // RMSNorm weights for key + void const* cos_sin_cache_void, // Pre-computed cos/sin cache + int64_t const* position_ids, // Position IDs for RoPE + int const num_tokens // Number of tokens +) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + if constexpr ((std::is_same_v) || + std::is_same_v) { + return; + } else { + #endif + + using Converter = vllm::_typeConvert; + static_assert(Converter::exists, + "Input QKV data type is not supported for this CUDA " + "architecture or toolkit version."); + using T_in = typename Converter::hip_type; + using T2_in = typename Converter::packed_hip_type; + + using CacheConverter = vllm::_typeConvert; + static_assert(CacheConverter::exists, + "Cache data type is not supported for this CUDA architecture " + "or toolkit version."); + using T_cache = typename CacheConverter::hip_type; + + T_in* qkv = reinterpret_cast(qkv_void); + T_in const* q_weight = reinterpret_cast(q_weight_void); + T_in const* k_weight = reinterpret_cast(k_weight_void); + T_cache const* cos_sin_cache = + reinterpret_cast(cos_sin_cache_void); + + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + + // Calculate global warp index to determine which head/token this warp + // processes + int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; + + // Total number of attention heads (Q and K) + int const 
total_qk_heads = num_heads_q + num_heads_k; + + // Determine which token and head type (Q or K) this warp processes + int const tokenIdx = globalWarpIdx / total_qk_heads; + int const localHeadIdx = globalWarpIdx % total_qk_heads; + + // Skip if this warp is assigned beyond the number of tokens + if (tokenIdx >= num_tokens) return; + + bool const isQ = localHeadIdx < num_heads_q; + int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q; + + int const num_heads = num_heads_q + num_heads_k + num_heads_v; + + static_assert(head_dim % (32 * 2) == 0, + "head_dim must be divisible by 64 (each warp processes one " + "head, and each thread gets even number of " + "elements)"); + constexpr int numElemsPerThread = head_dim / 32; + float elements[numElemsPerThread]; + constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16); + static_assert(elemSizeBytes % 4 == 0, + "numSizeBytes must be a multiple of 4"); + constexpr int vecSize = + elemSizeBytes / + 4; // Use packed_as to perform loading/saving. + using vec_T = typename tensorrt_llm::common::packed_as::type; + + int offsetWarp; // Offset for the warp + if (isQ) { + // Q segment: token offset + head offset within Q segment + offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim; + } else { + // K segment: token offset + entire Q segment + head offset within K + // segment + offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim + + headIdx * head_dim; + } + int offsetThread = offsetWarp + laneId * numElemsPerThread; + + // Sum of squares for RMSNorm + float sumOfSquares = 0.0f; + + // Load. + { + vec_T vec = *reinterpret_cast(&qkv[offsetThread]); + constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in); + #pragma unroll + for (int i = 0; i < num_packed_elems; i++) { + // Interpret the generic vector chunk as the specific packed type + T2_in packed_val = *(reinterpret_cast(&vec) + i); + // Convert to float2 for computation + float2 vals = Converter::convert(packed_val); + sumOfSquares += vals.x * vals.x; + sumOfSquares += vals.y * vals.y; + + elements[2 * i] = vals.x; + elements[2 * i + 1] = vals.y; + } + } + + // Reduce sum across warp using the utility function + sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares); + + // Compute RMS normalization factor + float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps); + + // Normalize elements + #pragma unroll + for (int i = 0; i < numElemsPerThread; i++) { + int dim = laneId * numElemsPerThread + i; + float weight = isQ ? Converter::convert(q_weight[dim]) + : Converter::convert(k_weight[dim]); + elements[i] *= rms_rcp * weight; + } + + // Apply RoPE to normalized elements + float elements2[numElemsPerThread]; // Additional buffer required for RoPE. + + int64_t pos_id = position_ids[tokenIdx]; + + // Calculate cache pointer for this position - similar to + // pos_encoding_kernels.cu + T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim; + int const embed_dim = head_dim / 2; + T_cache const* cos_ptr = cache_ptr; + T_cache const* sin_ptr = cache_ptr + embed_dim; + + if constexpr (interleave) { + // Perform interleaving. Use pre-computed cos/sin values. 
+ #pragma unroll + for (int i = 0; i < numElemsPerThread / 2; ++i) { + int const idx0 = 2 * i; + int const idx1 = 2 * i + 1; + + float const val0 = elements[idx0]; + float const val1 = elements[idx1]; + + int const dim_idx = laneId * numElemsPerThread + idx0; + int const half_dim = dim_idx / 2; + float const cos_val = + CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim)); + float const sin_val = + CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim)); + + elements[idx0] = val0 * cos_val - val1 * sin_val; + elements[idx1] = val0 * sin_val + val1 * cos_val; + } + } else { + // Before data exchange with in warp, we need to sync. + __syncwarp(); + // Get the data from the other half of the warp. Use pre-computed cos/sin + // values. + #pragma unroll + for (int i = 0; i < numElemsPerThread; i++) { + elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16); + if (laneId < 16) { + elements2[i] = -elements2[i]; + } + + int dim_idx = laneId * numElemsPerThread + i; + dim_idx = (dim_idx * 2) % head_dim; + int half_dim = dim_idx / 2; + // Use pre-computed cos/sin from cache + float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim)); + float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim)); + + elements[i] = elements[i] * cos_val + elements2[i] * sin_val; + } + // __shfl_xor_sync does not provide memfence. Need to sync again. + __syncwarp(); + } + + // Store. + { + vec_T vec; + constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in); + #pragma unroll + for (int i = 0; i < num_packed_elems; i++) { + // Convert from float2 back to the specific packed type + T2_in packed_val = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + // Place it into the generic vector + *(reinterpret_cast(&vec) + i) = packed_val; + } + *reinterpret_cast(&qkv[offsetThread]) = vec; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 + } + #endif +} + + // Borrowed from + // https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568 + #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) 
\ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + +template +void launchFusedQKNormRope(void* qkv, int const num_tokens, + int const num_heads_q, int const num_heads_k, + int const num_heads_v, int const head_dim, + float const eps, void const* q_weight, + void const* k_weight, void const* cos_sin_cache, + bool const interleave, int64_t const* position_ids, + cudaStream_t stream) { + constexpr int blockSize = 256; + + int const warpsPerBlock = blockSize / 32; + int const totalQKHeads = num_heads_q + num_heads_k; + int const totalWarps = num_tokens * totalQKHeads; + + int const gridSize = common::divUp(totalWarps, warpsPerBlock); + dim3 gridDim(gridSize); + dim3 blockDim(blockSize); + + switch (head_dim) { + case 64: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + case 128: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + case 256: + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + fusedQKNormRopeKernel + <<>>( + qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight, + k_weight, cos_sin_cache, position_ids, num_tokens); + }); + break; + default: + TORCH_CHECK(false, + "Unsupported head dimension for fusedQKNormRope: ", head_dim); + } +} +} // namespace tensorrt_llm::kernels + +void fused_qk_norm_rope( + torch::Tensor& qkv, // Combined QKV tensor [num_tokens, + // (num_heads_q+num_heads_k+num_heads_v)*head_dim] + int64_t num_heads_q, // Number of query heads + int64_t num_heads_k, // Number of key heads + int64_t num_heads_v, // Number of value heads + int64_t head_dim, // Dimension per head + double eps, // Epsilon for RMS normalization + torch::Tensor& q_weight, // RMSNorm weights for query [head_dim] + torch::Tensor& k_weight, // RMSNorm weights for key [head_dim] + torch::Tensor& cos_sin_cache, // Cos/sin cache [max_position, head_dim] + bool is_neox, // Whether RoPE is applied in Neox style + torch::Tensor& position_ids // Position IDs for RoPE [num_tokens] +) { + // Input validation + CHECK_INPUT(qkv); + CHECK_INPUT(position_ids); + CHECK_INPUT(q_weight); + CHECK_INPUT(k_weight); + CHECK_INPUT(cos_sin_cache); + CHECK_TYPE(position_ids, torch::kInt64); + + TORCH_CHECK(qkv.dim() == 2, + "QKV tensor must be 2D: [num_tokens, " + "(num_heads_q+num_heads_k+num_heads_v)*head_dim]"); + TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]"); + TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]"); + TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]"); + TORCH_CHECK(cos_sin_cache.dim() == 2, + "Cos/sin cache must be 2D: [max_position, head_dim]"); + TORCH_CHECK(q_weight.size(0) == head_dim, + "Query weights size must match head dimension"); + TORCH_CHECK(k_weight.size(0) == head_dim, + "Key weights size must match head dimension"); + TORCH_CHECK(cos_sin_cache.size(1) == head_dim, + "Cos/sin cache dimension must match head_dim"); + TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() && + qkv.scalar_type() == k_weight.scalar_type(), + "qkv, q_weight and k_weight must have the same dtype"); + + int64_t num_tokens = qkv.size(0); + TORCH_CHECK(position_ids.size(0) == num_tokens, + "Number of tokens in position_ids must 
match QKV"); + + int64_t total_heads = num_heads_q + num_heads_k + num_heads_v; + TORCH_CHECK( + qkv.size(1) == total_heads * head_dim, + "QKV tensor size must match total number of heads and head dimension"); + + auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device()); + + VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using qkv_scalar_t = scalar_t; + VLLM_DISPATCH_FLOATING_TYPES( + cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] { + using cache_scalar_t = scalar_t; + tensorrt_llm::kernels::launchFusedQKNormRope( + qkv.data_ptr(), static_cast(num_tokens), + static_cast(num_heads_q), static_cast(num_heads_k), + static_cast(num_heads_v), static_cast(head_dim), + static_cast(eps), q_weight.data_ptr(), k_weight.data_ptr(), + cos_sin_cache.data_ptr(), !is_neox, + reinterpret_cast(position_ids.data_ptr()), + stream); + }); + }); +} + +#endif // not USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 3f5cb799b774..f8bdc61aaa8e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,6 +92,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, + int64_t num_heads_k, int64_t num_heads_v, + int64_t head_dim, double eps, torch::Tensor& q_weight, + torch::Tensor& k_weight, torch::Tensor& cos_sin_cache, + bool is_neox, torch::Tensor& position_ids); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9c0f524dcab1..d4a69cbe7971 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -175,6 +175,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); +#ifndef USE_ROCM + // Function for fused QK Norm and RoPE + ops.def( + "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, " + "int num_heads_k, int num_heads_v, int head_dim, float eps, " + "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, " + "bool is_neox, Tensor position_ids) -> ()"); + ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); +#endif + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! 
logits, Tensor prompt_mask, " diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 21b9d0ae515d..6da06f1e66cf 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -29,6 +29,22 @@ struct _typeConvert { static constexpr bool exists = false; }; +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = float; + using packed_hip_type = float2; + using packed_hip_type4 = float4; // For 128-bit vectorization + + __device__ static __forceinline__ float convert(hip_type x) { return x; } + __device__ static __forceinline__ float2 convert(packed_hip_type x) { + return x; + } + __device__ static __forceinline__ float4 convert(packed_hip_type4 x) { + return x; + } +}; + #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion template <> @@ -37,14 +53,16 @@ struct _typeConvert { using hip_type = __half; using packed_hip_type = __half2; - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { + __device__ static __forceinline__ float convert(hip_type x) { + return __half2float(x); + } + __device__ static __forceinline__ float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { + __device__ static __forceinline__ hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { + __device__ static __forceinline__ packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }; @@ -58,16 +76,16 @@ struct _typeConvert { using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { + __device__ static __forceinline__ float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { + __device__ static __forceinline__ float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { + __device__ static __forceinline__ hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { + __device__ static __forceinline__ packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; @@ -95,10 +113,15 @@ struct alignas(16) _f16Vec { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; + if constexpr (std::is_same_v) { + data[i] += other.data[i]; + data[i + 1] += other.data[i + 1]; + } else { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } } } else { #pragma unroll @@ -111,10 +134,15 @@ struct alignas(16) _f16Vec { if constexpr (width % 2 == 0) { #pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; + if constexpr (std::is_same_v) { + data[i] *= other.data[i]; + data[i + 1] *= other.data[i + 1]; + } else { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } } } else { #pragma unroll diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py new file mode 100644 index 
000000000000..973123a3af92 --- /dev/null +++ b/tests/compile/test_qk_norm_rope_fusion.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.compile.backend import TestBackend +from vllm.attention import Attention, AttentionType +from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.qk_norm_rope_fusion import ( + FUSED_QK_ROPE_OP, + QKNormRoPEFusionPass, +) +from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + +RSQRT_OP = torch.ops.aten.rsqrt.default +INDEX_SELECT_OP = torch.ops.aten.index.Tensor + + +class QKNormRoPETestModel(torch.nn.Module): + def __init__( + self, + *, + num_heads: int, + num_kv_heads: int, + head_dim: int, + eps: float, + is_neox: bool, + vllm_config: VllmConfig, + dtype: torch.dtype, + prefix: str = "model.layers.0.self_attn.attn", + ) -> None: + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.rotary_dim = head_dim + self.eps = eps + self.dtype = dtype + + # Register layer metadata for the fusion pass via Attention. + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=1.0 / self.head_dim**0.5, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + prefix=prefix, + attn_type=AttentionType.DECODER, + ) + + self.q_norm = RMSNorm(self.head_dim, eps=self.eps) + self.k_norm = RMSNorm(self.head_dim, eps=self.eps) + self.rotary_emb = RotaryEmbedding( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position_embeddings=4096, + base=10000, + is_neox_style=is_neox, + dtype=self.dtype, + ) + self.enable_rms_norm_custom_op = self.q_norm.enabled() + self.enable_rope_custom_op = self.rotary_emb.enabled() + + def forward(self, qkv: torch.Tensor, positions: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + return q, k, v + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + ops = [] + if self.enable_rms_norm_custom_op: + ops.append(RMS_OP) + else: + ops.append(RSQRT_OP) + + if self.enable_rope_custom_op: + if self.rotary_emb.use_flashinfer: + ops.append(FLASHINFER_ROTARY_OP) + else: + ops.append(ROTARY_OP) + else: + ops.append(INDEX_SELECT_OP) + return ops + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [FUSED_QK_ROPE_OP] + + +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("is_neox", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_rope_custom_op", [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) 
+@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="Only test on cuda platform", +) +def test_qk_norm_rope_fusion( + eps, is_neox, enable_rms_norm_custom_op, enable_rope_custom_op, dtype +): + if not hasattr(torch.ops._C, "fused_qk_norm_rope"): + pytest.skip("fused_qk_norm_rope custom op not available") + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + custom_ops: list[str] = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_rope_custom_op: + custom_ops.append("+rotary_embedding") + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig( + enable_qk_norm_rope_fusion=True, + enable_noop=True, + ), + ), + ) + + num_heads, num_kv_heads, head_dim = 16, 4, 128 + T = 5 + + with set_current_vllm_config(vllm_config): + model = QKNormRoPETestModel( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + eps=eps, + is_neox=is_neox, + vllm_config=vllm_config, + dtype=dtype, + ) + + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = QKNormRoPEFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend_baseline = TestBackend(noop_pass, cleanup_pass) + + qkv = torch.randn(T, model.q_size + 2 * model.kv_size) + pos = torch.arange(T, dtype=torch.long, device=qkv.device) + qkv_unfused = qkv.clone() + pos_unfused = pos.clone() + + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + model_fused = torch.compile(model, backend=backend) + q_fused, k_fused, v_fused = model_fused(qkv, pos) + + torch._dynamo.mark_dynamic(qkv_unfused, 0) + torch._dynamo.mark_dynamic(pos_unfused, 0) + model_unfused = torch.compile(model, backend=backend_baseline) + q_unfused, k_unfused, v_unfused = model_unfused(qkv_unfused, pos_unfused) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL) + + assert fusion_pass.matched_count == 1 + + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/kernels/core/test_fused_qk_norm_rope.py b/tests/kernels/core/test_fused_qk_norm_rope.py new file mode 100644 index 000000000000..88bb7691ec3b --- /dev/null +++ b/tests/kernels/core/test_fused_qk_norm_rope.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +IS_NEOX = [True, False] +EPS_VALUES = [1e-5, 1e-6] +SEEDS = [13] +CUDA_DEVICES = ["cuda:0"] + + +def _apply_qk_norm_rope( + qkv: torch.Tensor, + positions: torch.Tensor, + q_norm: RMSNorm, + k_norm: RMSNorm, + rope: RotaryEmbedding, + num_heads_q: int, + num_heads_kv: int, + head_dim: int, +) -> torch.Tensor: + q_size = num_heads_q * head_dim + kv_size = num_heads_kv * head_dim + + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + 
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + + q, k = rope.forward_native(positions, q, k) + return torch.cat([q, k, v], dim=-1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="fused_qk_norm_rope custom op requires cuda platform", +) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox", IS_NEOX) +@pytest.mark.parametrize("eps", EPS_VALUES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_fused_qk_norm_rope_matches_reference( + device: str, + dtype: torch.dtype, + is_neox: bool, + eps: float, + seed: int, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + num_heads, num_kv_heads, head_dim = 16, 4, 128 + num_tokens = 4 + + total_dim = (num_heads + 2 * num_kv_heads) * head_dim + qkv_base = torch.randn(num_tokens, total_dim, dtype=dtype, device=device) + qkv_fused = qkv_base.clone() + positions = torch.arange(num_tokens, dtype=torch.long, device=device) + + q_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + k_norm = RMSNorm(head_dim, eps=eps).to(device=device, dtype=dtype) + q_norm.weight.data.normal_(mean=1.0, std=0.1) + k_norm.weight.data.normal_(mean=1.0, std=0.1) + q_weight = q_norm.weight.data + k_weight = k_norm.weight.data + + rope = RotaryEmbedding( + head_size=head_dim, + rotary_dim=head_dim, + max_position_embeddings=4096, + base=10000.0, + is_neox_style=is_neox, + dtype=dtype, + ).to(device) + + ref_result = _apply_qk_norm_rope( + qkv=qkv_base, + positions=positions, + q_norm=q_norm, + k_norm=k_norm, + rope=rope, + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + head_dim=head_dim, + ) + + opcheck( + torch.ops._C.fused_qk_norm_rope, + ( + qkv_fused.clone(), + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ), + ) + + torch.ops._C.fused_qk_norm_rope( + qkv_fused, + num_heads, + num_kv_heads, + num_kv_heads, + head_dim, + eps, + q_weight, + k_weight, + rope.cos_sin_cache, + is_neox, + positions.view(-1), + ) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close( + qkv_fused, + ref_result, + atol=ATOL, + rtol=RTOL, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 36aab503dee7..136a3193efb5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -329,6 +329,7 @@ def rms_norm( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float ) -> None: # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + # If removed, also need to remove contiguous in MatcherRMSNorm input_contiguous = input.contiguous() torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) @@ -339,6 +340,34 @@ def fused_add_rms_norm( torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def fused_qk_norm_rope( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + position_ids: torch.Tensor, +) -> None: + torch.ops._C.fused_qk_norm_rope( + qkv, + num_heads_q, + num_heads_k, + 
num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + cos_sin_cache, + is_neox, + position_ids, + ) + + def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 29462d9ff0e5..126ad35e527a 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -132,6 +132,23 @@ def __call__(self, graph: torch.fx.Graph): "input_global_scale", ), ) + # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper. + elif at_target == torch.ops._C.fused_qk_norm_rope.default: + mutated_args = {1: "qkv"} + args = ( + "qkv", + "num_heads_q", + "num_heads_k", + "num_heads_v", + "head_dim", + "eps", + "q_weight", + "k_weight", + "cos_sin_cache", + "is_neox", + "position_ids", + ) + self.defunctionalize(graph, node, mutated_args=mutated_args, args=args) else: continue # skip the count diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8f0ad2d69fbe..1d6e297b495e 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -44,6 +44,10 @@ def empty_i32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") +def empty_i64(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda") + + RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 383fe6033a6d..38eb4e5301a1 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -18,10 +18,13 @@ kFp8StaticTensorSym, kNvfp4Quant, ) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +ROTARY_OP = torch.ops._C.rotary_embedding.default +FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 @@ -58,6 +61,9 @@ def __call__(self, *args, **kws): def empty(self, *args, **kws): return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) + def empty_int64(self, *args, **kws): + return torch.empty(*args, dtype=torch.int64, device=self.device, **kws) + def empty_f32(self, *args, **kws): return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) @@ -66,6 +72,77 @@ def inputs(self) -> list[torch.Tensor]: raise NotImplementedError +class MatcherRotaryEmbedding(MatcherCustomOp): + def __init__( + self, + is_neox: bool, + head_size: int, + num_heads: int, + num_kv_heads: int, + use_flashinfer: bool = False, + enabled: bool | None = None, + ) -> None: + if enabled is None: + enabled = RotaryEmbedding.enabled() + + super().__init__(enabled) + self.is_neox = is_neox + self.head_size = head_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_size + self.kv_size = self.num_kv_heads * self.head_size + self.rotary_dim = head_size + if use_flashinfer: + self.rotary_op = FLASHINFER_ROTARY_OP + else: + self.rotary_op = ROTARY_OP + + def inputs(self) -> list[torch.Tensor]: + positions = self.empty_int64(5) + query = self.empty(5, self.q_size) + key = self.empty(5, self.kv_size) + cos_sin_cache = self.empty(4096, self.rotary_dim) + return [positions, query, key, cos_sin_cache] 
+ + def forward_custom( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + result = auto_functionalized( + self.rotary_op, + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + ) + query_out = result[1] + key_out = result[2] if len(result) > 2 else None + return query_out, key_out + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return RotaryEmbedding.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + cos_sin_cache, + self.is_neox, + ) + + class MatcherRMSNorm(MatcherCustomOp): def __init__(self, epsilon: float, enabled: bool | None = None): if enabled is None: @@ -85,10 +162,12 @@ def forward_custom( weight: torch.Tensor, ) -> torch.Tensor: result = torch.empty_like(input) + # TODO: support non-contiguous input for RMSNorm and remove this + input_contiguous = input.contiguous() _, result = auto_functionalized( RMS_OP, result=result, - input=input, + input=input_contiguous, weight=weight, epsilon=self.epsilon, ) diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index dfda2adf1d3b..0c2210d72ce0 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -17,6 +17,7 @@ from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import RMSNormQuantFusionPass from .fusion_attn import AttnFusionPass + from .qk_norm_rope_fusion import QKNormRoPEFusionPass if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass @@ -109,6 +110,9 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] + if self.pass_config.enable_qk_norm_rope_fusion: + self.passes += [QKNormRoPEFusionPass(config)] + # needs a functional graph self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/qk_norm_rope_fusion.py new file mode 100644 index 000000000000..e3c399e07906 --- /dev/null +++ b/vllm/compilation/qk_norm_rope_fusion.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.attention import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + +from .fusion import empty_bf16, empty_fp32, empty_i64 +from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) + +FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default + + +class QkNormRopePattern: + """ + Match the unfused sequence in attention blocks and replace with the fused op. 
+ + Unfused (conceptually): + q, k, v = split(qkv, [qsz, kvsz, kvsz], -1) + qh = reshape(q, [-1, num_heads, head_dim]) + kh = reshape(k, [-1, num_kv_heads, head_dim]) + qn = rms_norm(qh, q_weight, eps) + kn = rms_norm(kh, k_weight, eps) + qf = reshape(qn, [-1, num_heads * head_dim]) + kf = reshape(kn, [-1, num_kv_heads * head_dim]) + qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox) + return qf, kf, v + + Fused replacement: + fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim, + eps, q_weight, k_weight, cos_sin_cache, is_neox, + positions.view(-1)) + return split(qkv, [qsz, kvsz, kvsz], -1) + """ + + def __init__( + self, + head_dim: int, + num_heads: int, + num_kv_heads: int, + eps: float, + is_neox: bool, + rope_flashinfer: bool = False, + ) -> None: + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + self.rmsnorm_matcher = MatcherRMSNorm(eps) + self.is_neox = is_neox + self.rope_flashinfer = rope_flashinfer + self.rope_matcher = MatcherRotaryEmbedding( + is_neox=is_neox, + head_size=self.head_dim, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + use_flashinfer=self.rope_flashinfer, + ) + + def get_inputs(self): + # Sample inputs to help pattern tracing + T = 5 + qkv = empty_bf16(T, self.q_size + 2 * self.kv_size) + positions = empty_i64(T) + q_weight = empty_bf16(1, self.head_dim) + k_weight = empty_bf16(1, self.head_dim) + if self.rope_flashinfer: + cos_sin_cache = empty_fp32(4096, self.head_dim) + else: + cos_sin_cache = empty_bf16(4096, self.head_dim) + return [ + qkv, + positions, + q_weight, + k_weight, + cos_sin_cache, + ] + + @staticmethod + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): + def wrapped(*args, **kwargs): + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + + view_to_reshape(gm) + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ): + # split qkv -> q,k,v + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Q path: view -> RMS -> view back to q.shape + q_by_head = q.view( + *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim + ) + q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) + q_flat = q_normed_by_head.view(q.shape) + + # K path: view -> RMS -> view back to k.shape + k_by_head = k.view( + *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim + ) + k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) + k_flat = k_normed_by_head.view(k.shape) + + # RoPE: apply to flattened q/k + q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache) + return q_rope, k_rope, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ): + # Run fused qk_norm_rope op + result = auto_functionalized( + FUSED_QK_ROPE_OP, + qkv=qkv, + num_heads_q=self.num_heads, + num_heads_k=self.num_kv_heads, + num_heads_v=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.eps, + q_weight=q_weight, + k_weight=k_weight, + 
cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + position_ids=positions.view(-1), + ) + result_qkv = result[1] + + # Split back to q,k,v and return + return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # NOTE: use fx_view_to_reshape to unify view/reshape to simplify + # pattern and increase matching opportunities + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + QkNormRopePattern.wrap_trace_fn( + pm.fwd_only, + QkNormRopePattern.fx_view_to_reshape, + ), + pm_pass, + ) + + +class QKNormRoPEFusionPass(VllmPatternMatcherPass): + """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.""" + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="qk_norm_rope_fusion_pass" + ) + + dtype = config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logger.warning_once( + "QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype + ) + return + + # use one attn layer to get meta (such as head_dim) for QkNormRopePattern + attn_layers: dict[str, Attention] = get_layers_from_vllm_config( + config, Attention + ) + if len(attn_layers) == 0: + logger.warning_once( + "QK Norm+RoPE fusion enabled, but no Attention layers were discovered." + ) + return + layer = next(iter(attn_layers.values())) + + for epsilon in [1e-5, 1e-6]: + for neox in [True, False]: + if RotaryEmbedding.enabled(): + for rope_flashinfer in [False, True]: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + rope_flashinfer=rope_flashinfer, + ).register(self.patterns) + else: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + self.matched_count = self.patterns.apply(graph) + logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count) + + def uuid(self): + return VllmInductorPass.hash_source(self, QkNormRopePattern) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 92cf16f259fe..9c9557df4e73 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -129,6 +129,8 @@ class PassConfig: 8: 1, # 1MB }, }, where key is the device capability""" + enable_qk_norm_rope_fusion: bool = False + """Whether to enable the fused Q/K RMSNorm + RoPE pass.""" # TODO(luka) better pass enabling system. @@ -182,6 +184,12 @@ def __post_init__(self) -> None: "Fusion enabled but reshape elimination disabled. " "Allreduce + rms norm + quant (fp8) fusion might not work" ) + if self.enable_qk_norm_rope_fusion and not current_platform.is_cuda(): + logger.warning_once( + "QK Norm + RoPE fusion enabled but the current platform is not " + "CUDA. The fusion will be disabled." + ) + self.enable_qk_norm_rope_fusion = False @config @@ -640,6 +648,11 @@ def __post_init__(self) -> None: if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) + if self.pass_config.enable_qk_norm_rope_fusion: + # TODO(zhuhaoran): support rope native forward match and remove this. 
+ # Linked issue: https://github.com/vllm-project/vllm/issues/28042 + self.custom_ops.append("+rotary_embedding") + if ( is_torch_equal_or_newer("2.9.0.dev") and "combo_kernels" not in self.inductor_compile_config diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 2ef54e75df44..ce4f40680b0a 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -98,35 +98,56 @@ def __init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) - def forward_native( - self, + @staticmethod + def forward_static( positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor | None = None, + key: torch.Tensor | None, + head_size: int, + rotary_dim: int, + cos_sin_cache: torch.Tensor, + is_neox_style: bool, ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" positions = positions.flatten() num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) + cos_sin = cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) + query = query.view(num_tokens, -1, head_size) + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. cross-layer KV sharing if key is not None: key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) + key = key.view(num_tokens, -1, head_size) + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """A PyTorch-native implementation of forward().""" + return self.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + self.cos_sin_cache, + self.is_neox_style, + ) + def forward_cuda( self, positions: torch.Tensor, From 05576df85c5274ee3045d90b0779d4adeecc09b9 Mon Sep 17 00:00:00 2001 From: xuebwang-amd Date: Wed, 12 Nov 2025 01:05:22 +0800 Subject: [PATCH 066/183] [ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model (#24239) Signed-off-by: xuebwang-amd Co-authored-by: fxmarty-amd Co-authored-by: Cyrus Leung --- docs/features/quantization/quark.md | 34 ++++++++- tests/quantization/test_mixed_precision.py | 69 +++++++++++++++++++ .../layers/quantization/quark/quark.py | 32 +++++++-- 3 files changed, 127 insertions(+), 8 deletions(-) create mode 100755 tests/quantization/test_mixed_precision.py diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 385e3bbb8712..be0702f4c9e1 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -281,4 +281,36 @@ python quantize_quark.py 
--model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--group_size 32
```
-The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights.
+The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations.
+
+## Using Quark Quantized Layerwise Auto Mixed Precision (AMP) Models
+
+vLLM also supports loading layerwise mixed precision models quantized using AMD Quark. Currently, a mixed scheme of {MXFP4, FP8} is supported, where FP8 denotes the FP8 per-tensor scheme. More mixed precision schemes are planned for the near future, including:
+
+- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., a mix of {MXFP4, FP8, BF16/FP16}
+- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}
+
+Although one can maximize serving throughput by using the lowest precision supported on a given device (e.g. MXFP4 for AMD Instinct MI355, FP8 for AMD Instinct MI300), such aggressive schemes can be detrimental to the accuracy recovered from quantization on target tasks. Mixed precision strikes a balance between accuracy and throughput.
+
+There are two steps to generate and deploy a mixed precision model quantized with AMD Quark, as shown below.
+
+### 1. Quantize a model using mixed precision in AMD Quark
+
+First, the layerwise mixed-precision configuration for a given LLM is searched, and the model is then quantized using AMD Quark. A detailed tutorial on the Quark APIs will be provided later.
+
+As examples, we provide some ready-to-use quantized mixed precision models to show the usage in vLLM and the accuracy benefits. They are:
+
+- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
+- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
+- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
+
+### 2. Run inference on the quantized mixed precision model in vLLM
+
+Models quantized with AMD Quark using mixed precision can be loaded natively in vLLM and evaluated, e.g., with lm-evaluation-harness as follows:
+
+```bash
+lm_eval --model vllm \
+    --model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \
+    --tasks mmlu \
+    --batch_size auto
+```
diff --git a/tests/quantization/test_mixed_precision.py b/tests/quantization/test_mixed_precision.py
new file mode 100755
index 000000000000..51526470b423
--- /dev/null
+++ b/tests/quantization/test_mixed_precision.py
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Test quark-quantized {MXFP4, FP8} mixed precision models.
+
+Run `pytest tests/quantization/test_mixed_precision.py`.
+ +""" + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import lm_eval +import pytest +from packaging import version + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class EvaluationConfig: + model_name: str + + def get_model_args(self) -> str: + return ( + f"pretrained={self.model_name}," + "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False" + ) + + +TEST_CONFIGS = { + # Mixed-precision (AMP) model + # - Demonstrates end-to-end pipeline functionality + "amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72}, + # Non-mixed-precision (PTQ) model + # - Reference for pipeline compatibility verification -> No conflicts or breakings + "amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": { + "arc_challenge": 0.53, + "mmlu": 0.61, + }, +} + + +@pytest.mark.parametrize("model_name, accuracy_numbers", TEST_CONFIGS.items()) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +def test_mixed_precision_model_accuracies(model_name: str, accuracy_numbers: dict): + results = lm_eval.simple_evaluate( + model="vllm", + model_args=EvaluationConfig(model_name).get_model_args(), + tasks=list(accuracy_numbers.keys()), + batch_size=8, + ) + + rtol = 0.05 + + for task, expect_accuracy in accuracy_numbers.items(): + measured_accuracy = results["results"][task]["acc,none"] + assert ( + measured_accuracy - rtol < expect_accuracy + and measured_accuracy + rtol > expect_accuracy + ), f"Expected: {expect_accuracy} | Measured: {measured_accuracy}" diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index d5459594b798..095a66ef10f9 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -114,7 +114,14 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": layer_quant_names = list(layer_quant_config.keys()) layer_quant_set = set(layer_quant_names) - if not kv_cache_set.issubset(layer_quant_set): + if not ( + kv_cache_set.issubset(layer_quant_set) + or any( + fnmatch.fnmatchcase(layer_quant, pat) + for layer_quant in list(layer_quant_set) + for pat in list(kv_cache_set) + ) + ): raise ValueError( "The Quark quantized model has the " "kv_cache_group parameter setting, " @@ -124,10 +131,15 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": ) q_configs = [ - cast(dict[str, Any], layer_quant_config.get(name)) - for name in kv_cache_group + quant_cfg + for name, quant_cfg in layer_quant_config.items() + if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group) ] - if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): + + if not all( + deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"]) + for q_config in q_configs + ): raise ValueError( "The quantization method used for kv_cache should " "be the same, but the quantization method for the " @@ -312,9 +324,15 @@ def _find_matched_config( layer_quant_config = cast( dict[str, Any], self.quant_config.get("layer_quant_config") ) - for name_pattern in layer_quant_config: - if fnmatch.fnmatch(layer_name, name_pattern): - return layer_quant_config[name_pattern] + + def _matches_pattern(layer_name, pattern): + if "*" not in 
pattern: + return layer_name in pattern + return fnmatch.fnmatch(layer_name, pattern) + + for name_pattern, config in layer_quant_config.items(): + if _matches_pattern(layer_name, name_pattern): + return config layer_type = cast(str, type(module)) layer_type_quant_config = cast( From 5a1271d83a65be5ed8dc3e4c990ed42074197db3 Mon Sep 17 00:00:00 2001 From: xuebwang-amd Date: Wed, 12 Nov 2025 01:06:00 +0800 Subject: [PATCH 067/183] [Quantization] fix attention quantization of gpt_oss model (#27334) Signed-off-by: xuebwang-amd --- .../test_gpt_oss_attn_quantization.py | 80 +++++++++++++++++++ .../layers/quantization/mxfp4.py | 15 +++- vllm/model_executor/models/gpt_oss.py | 10 ++- 3 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 tests/models/quantization/test_gpt_oss_attn_quantization.py diff --git a/tests/models/quantization/test_gpt_oss_attn_quantization.py b/tests/models/quantization/test_gpt_oss_attn_quantization.py new file mode 100644 index 000000000000..780165ea2ba7 --- /dev/null +++ b/tests/models/quantization/test_gpt_oss_attn_quantization.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test attention quantization of gpt-oss model. +The qkv_proj and o_proj in self_attention can be either quantized or excluded. + +Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`. + +""" + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import huggingface_hub +import lm_eval +import pytest +from packaging import version + +MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"] + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") + + +def has_huggingface_access(repo): + try: + huggingface_hub.list_repo_refs(repo) + return True + except huggingface_hub.errors.RepositoryNotFoundError: + return False + + +HF_HUB_AMD_ORG_ACCESS = all( + [has_huggingface_access(model_name) for model_name in MODEL_NAMES] +) + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class EvaluationConfig: + model_name: str + + def get_model_args(self) -> str: + return ( + f"pretrained={self.model_name}," + "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False" + ) + + +EXPECTED_ACCURACIES = {"arc_challenge": 0.20} + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not HF_HUB_AMD_ORG_ACCESS, + reason="Read access to huggingface.co/amd is required for this test.", +) +@pytest.mark.parametrize("model_name", MODEL_NAMES) +@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items()) +def test_gpt_oss_attention_quantization( + model_name: str, task_name: str, expected_accuracy: float +): + measured_accuracy = lm_eval.simple_evaluate( + model="vllm", + model_args=EvaluationConfig(model_name).get_model_args(), + tasks=task_name, + batch_size="auto", + )["results"][task_name]["acc,none"] + + rtol = 0.05 + assert ( + measured_accuracy - rtol < expected_accuracy + and measured_accuracy + rtol > expected_accuracy + ), f"Expected: {expected_accuracy} | Measured: {measured_accuracy}" diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 4e51249f2d25..8d7297a0a1b3 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ 
b/vllm/model_executor/layers/quantization/mxfp4.py @@ -190,14 +190,25 @@ def get_quant_method( fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() - raise NotImplementedError("Mxfp4 linear layer is not implemented") + # TODO: Add support for MXFP4 Linear Method. + # MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation + # if you are interested in enabling MXFP4 here. + logger.warning_once( + "MXFP4 linear layer is not implemented - falling back to " + "UnquantizedLinearMethod." + ) + return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): if current_platform.is_xpu(): return IpexMxfp4MoEMethod(layer.moe_config) else: return Mxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): - raise NotImplementedError("Mxfp4 attention layer is not implemented") + # TODO: Add support for MXFP4 Attention. + logger.warning_once( + "MXFP4 attention layer is not implemented. " + "Skipping quantization for this layer." + ) return None diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 04038ae74882..291ac833f26a 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -198,6 +198,7 @@ class TransformerBlock(torch.nn.Module): def __init__( self, vllm_config: VllmConfig, + quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() @@ -207,7 +208,10 @@ def __init__( self.layer_idx = extract_layer_index(prefix) self.attn = OAIAttention( - config, prefix=f"{prefix}.attn", cache_config=cache_config + config, + prefix=f"{prefix}.attn", + quant_config=quant_config, + cache_config=cache_config, ) self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) @@ -243,6 +247,7 @@ def __init__( ): super().__init__() self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( @@ -254,6 +259,7 @@ def __init__( lambda prefix: TransformerBlock( vllm_config, prefix=prefix, + quant_config=self.quant_config, ), prefix=f"{prefix}.layers", ) @@ -645,7 +651,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): - packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ From e55342491968a56d39dc8e03e6cf39d12fef5dcd Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Tue, 11 Nov 2025 09:09:47 -0800 Subject: [PATCH 068/183] [CI/Build] Refactor Attention backend for test_prefix_prefill from xformers to SDPA (#28424) Signed-off-by: zhewenli Signed-off-by: Roger Wang Co-authored-by: Roger Wang --- .../kernels/attention/test_prefix_prefill.py | 310 +++++++++++------- 1 file changed, 194 insertions(+), 116 deletions(-) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 65972d02f2f6..78cdbbbf7379 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -8,10 +8,8 @@ import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +import torch.nn.functional as F -from tests.kernels.utils 
import make_alibi_bias from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform @@ -28,6 +26,74 @@ OPS = [chunked_prefill_paged_decode, context_attention_fwd] +def create_causal_attention_mask_for_sdpa( + query_lens: list[int], + seq_lens: list[int], + sliding_window: int = 0, + device: torch.device = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + total_queries = sum(query_lens) + total_keys = sum(seq_lens) + + # Create a mask filled with -inf + mask = torch.full( + (total_queries, total_keys), float("-inf"), device=device, dtype=dtype + ) + + query_start = 0 + key_start = 0 + + for query_len, seq_len in zip(query_lens, seq_lens): + query_end = query_start + query_len + key_end = key_start + seq_len + q_indices = torch.arange(query_len, device=device) + k_indices = torch.arange(seq_len, device=device) + q_pos_in_seq = seq_len - query_len + q_indices + + valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None] + + if sliding_window > 0: + valid_mask &= k_indices[None, :] >= ( + q_pos_in_seq[:, None] - sliding_window + 1 + ) + + mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0 + + query_start = query_end + key_start = key_end + + return mask + + +def create_alibi_causal_mask( + query_len: int, + seq_len: int, + alibi_slopes: torch.Tensor, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + query_pos = torch.arange( + seq_len - query_len, seq_len, device=device, dtype=torch.float32 + ) + key_pos = torch.arange(seq_len, device=device, dtype=torch.float32) + + rel_pos = key_pos[None, :] - query_pos[:, None] + + # Apply ALiBi slopes: [num_heads, query_len, seq_len] + alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :] + alibi_bias = alibi_bias.to(dtype) + + # Apply causal mask: prevent attending to future positions + # causal_mask[i, j] = True if key_pos[j] <= query_pos[i] + causal_mask = key_pos[None, :] <= query_pos[:, None] + alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf")) + + # Add batch dimension: [1, num_heads, query_len, seq_len] + # SDPA expects batch dimension even for single sequences + return alibi_bias.unsqueeze(0) + + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -52,6 +118,13 @@ def test_contexted_kv_attention( "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" ) + if ( + current_platform.is_rocm() + and op is chunked_prefill_paged_decode + and kv_cache_dtype == "fp8_e5m2" + ): + pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") + current_platform.seed_everything(0) torch.set_default_device(device) @@ -96,16 +169,16 @@ def test_contexted_kv_attention( ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) + values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) + b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) + b_ctx_len 
= torch.tensor(ctx_lens, dtype=torch.int32) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum( - torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0 ) for i in range(BS): for j in range(query_lens[i]): @@ -189,56 +262,57 @@ def test_contexted_kv_attention( scale = float(1.0 / (head_size**0.5)) - attn_op = xops.fmha.cutlass.FwOp() + # Reshape for SDPA: (seq_len, num_heads, head_size) -> + # (1, num_heads, seq_len, head_size) + query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size) + query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape( + 1, num_heads, num_tokens, head_size + ) - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - query = query.view( - query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1] - ) - key = key[:, :, None, :].expand( - key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] - ) - value = value[:, :, None, :].expand( - value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] - ) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) + # Expand key and value for GQA/MQA to match query heads + key_sdpa = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape( + 1, num_heads, sum(seq_lens), head_size + ) - attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens + value_sdpa = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] ) - if sliding_window > 0: - attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window) - output_ref = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias, - p=0.0, + value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape( + 1, num_heads, sum(seq_lens), head_size + ) + + attn_mask = create_causal_attention_mask_for_sdpa( + query_lens, seq_lens, sliding_window, device=device, dtype=dtype + ) + + output_ref = F.scaled_dot_product_attention( + query_sdpa, + key_sdpa, + value_sdpa, + attn_mask=attn_mask, + dropout_p=0.0, scale=scale, - op=attn_op, ) torch.cuda.synchronize() start_time = time.time() - output_ref = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias, - p=0.0, + output_ref = F.scaled_dot_product_attention( + query_sdpa, + key_sdpa, + value_sdpa, + attn_mask=attn_mask, + dropout_p=0.0, scale=scale, - op=attn_op, ) torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") - output_ref = output_ref.reshape(output.shape) + print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") + + # Reshape output back to (num_tokens, num_heads, head_size) + output_ref = output_ref.view(num_heads, num_tokens, head_size) + output_ref = output_ref.permute(1, 0, 2).contiguous() atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi( "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" ) + if ( + current_platform.is_rocm() + and op is chunked_prefill_paged_decode + and 
kv_cache_dtype == "fp8_e5m2" + ): + pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") + current_platform.seed_everything(0) torch.set_default_device(device) @@ -331,16 +412,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) + values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) + b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum( - torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0 ) for i in range(BS): for j in range(query_lens[i]): @@ -423,78 +504,75 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) - # NOTE(DefTruth): In order to reuse _make_alibi_bias function, - # we have to pad query tensor before MQA/GQA expanding. - if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) - query_pad.uniform_(-1e-3, 1e-3) - seq_start = 0 - query_start = 0 - for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len - query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat( - [ - torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...], - ], - dim=0, - ) - seq_start += seq_len - query_start += query_len - query = query_pad - - if num_kv_heads != num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # - # see also: vllm/model_executor/layers/attention.py - key = key[:, :, None, :].expand( - key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] - ) - value = value[:, :, None, :].expand( - value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] - ) - # [seq, num_kv_heads, num_queries_per_kv, dk]=> - # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the - # codebase. We save some time reshaping alibi matrix at runtime. 
- key = key.reshape(key.shape[0], -1, key.shape[-1]) - value = value.reshape(value.shape[0], -1, value.shape[-1]) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + # Prepare query, key, value for SDPA + # Expand key and value for GQA/MQA to match query heads + key_expanded = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value_expanded = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) + output_ref = torch.empty_like(output) - seq_start = 0 - query_start = 0 + + torch.cuda.synchronize() start_time = time.time() - # Attention with alibi slopes. - # FIXME(DefTruth): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/v1/attention/backends/xformers.py#L343 + + query_start = 0 + key_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): - seq_end = seq_start + seq_len query_end = query_start + query_len - out = xops.memory_efficient_attention_forward( - query[:, seq_start:seq_end], - key[:, seq_start:seq_end], - value[:, seq_start:seq_end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale, + key_end = key_start + seq_len + + # Get query, key, value for this sequence + q = query[query_start:query_end] # [query_len, num_heads, head_size] + k = key_expanded[ + key_start:key_end + ] # [seq_len, num_kv_heads, num_queries_per_kv, head_size] + v = value_expanded[ + key_start:key_end + ] # [seq_len, num_kv_heads, num_queries_per_kv, head_size] + + # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size) + q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size) + q_sdpa = ( + q_sdpa.permute(1, 2, 0, 3) + .reshape(1, num_heads, query_len, head_size) + .contiguous() + ) + + k_sdpa = ( + k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous() + ) + v_sdpa = ( + v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous() ) - out = out.view_as(query[:, seq_start:seq_end]).view( - seq_len, num_heads, head_size + + # Create ALiBi causal mask for this sequence using utility function + alibi_mask = create_alibi_causal_mask( + query_len, seq_len, alibi_slopes, device, dtype + ) + + # Compute attention + out = F.scaled_dot_product_attention( + q_sdpa, + k_sdpa, + v_sdpa, + attn_mask=alibi_mask, + dropout_p=0.0, + scale=scale, ) - output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...]) - seq_start += seq_len - query_start += query_len + + # Reshape output back to [query_len, num_heads, head_size] + out = out.view(num_heads, query_len, head_size).permute(1, 0, 2) + output_ref[query_start:query_end].copy_(out) + + query_start = query_end + key_start = key_end + torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") + print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) From 684f2545851ee0ee49be9a80545ed497324f1a96 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 11 Nov 2025 11:13:51 -0600 Subject: [PATCH 069/183] Prefer FlashAttention MLA as default over FlashMLA (#27363) Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 43daf5e75b66..22c6dde754d0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -55,15 +55,15 @@ def _get_backend_priorities( return [ AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.FLASHINFER_MLA, - AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.FLASHMLA_SPARSE, ] else: return [ - AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.FLASHMLA_SPARSE, From 6c3c0f8235cacce28982687e362b80d953ea7617 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Tue, 11 Nov 2025 10:02:23 -0800 Subject: [PATCH 070/183] [Kernel] Optimize rms_norm kernel (#27931) Signed-off-by: Xin Yang --- csrc/dispatch_utils.h | 29 ++++++++++++++++++++++ csrc/layernorm_kernels.cu | 39 +++++++++++++++++++++--------- csrc/layernorm_quant_kernels.cu | 43 ++++++++++++++++++++++----------- 3 files changed, 86 insertions(+), 25 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 995374a50b03..9ae0ed975edd 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -88,3 +88,32 @@ #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \ + switch (VEC_SIZE) { \ + case 16: { \ + constexpr int vec_size = 16; \ + __VA_ARGS__(); \ + break; \ + } \ + case 8: { \ + constexpr int vec_size = 8; \ + __VA_ARGS__(); \ + break; \ + } \ + case 4: { \ + constexpr int vec_size = 4; \ + __VA_ARGS__(); \ + break; \ + } \ + case 2: { \ + constexpr int vec_size = 2; \ + __VA_ARGS__(); \ + break; \ + } \ + default: { \ + constexpr int vec_size = 1; \ + __VA_ARGS__(); \ + break; \ + } \ + } diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 8cfcf9f41283..48771e4b3aff 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -10,7 +10,7 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. 
-template +template __global__ void rms_norm_kernel( scalar_t* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] @@ -21,7 +21,6 @@ __global__ void rms_norm_kernel( float variance = 0.0f; const scalar_t* input_row = input + blockIdx.x * input_stride; - constexpr int VEC_SIZE = 8; auto vec_op = [&variance](const vec_n_t& vec) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { @@ -45,10 +44,20 @@ __global__ void rms_norm_kernel( } __syncthreads(); - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * input_stride + idx]; - out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + scalar_t* out_row = out + blockIdx.x * hidden_size; + auto* v_in = reinterpret_cast*>(input_row); + auto* v_w = reinterpret_cast*>(weight); + auto* v_out = reinterpret_cast*>(out_row); + for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) { + vec_n_t dst; + vec_n_t src1 = v_in[i]; + vec_n_t src2 = v_w[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + float x = static_cast(src1.val[j]); + dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j]; + } + v_out[i] = dst; } } @@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] int num_tokens = input_view.numel() / hidden_size; int64_t input_stride = input_view.stride(-2); + // For large num_tokens, use smaller blocks to increase SM concurrency. + const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input_view.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input_view.data_ptr(), - input_stride, weight.data_ptr(), epsilon, num_tokens, - hidden_size); + const int calculated_vec_size = + std::gcd(16 / sizeof(scalar_t), hidden_size); + const int block_size = + std::min(hidden_size / calculated_vec_size, max_block_size); + dim3 block(block_size); + VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input_view.data_ptr(), + input_stride, weight.data_ptr(), epsilon, num_tokens, + hidden_size); + }); }); } diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0f7f034ee180..0880b8d50a79 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -18,7 +18,7 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. 
-template +template __global__ void rms_norm_static_fp8_quant_kernel( fp8_type* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] @@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel( const scalar_t* input_row = input + blockIdx.x * input_stride; - constexpr int VEC_SIZE = 8; auto vec_op = [&variance](const vec_n_t& vec) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { @@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel( // invert scale to avoid division float const scale_inv = 1.0f / *scale; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * input_stride + idx]; - float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; - out[blockIdx.x * hidden_size + idx] = - scaled_fp8_conversion(out_norm, scale_inv); + auto* v_in = reinterpret_cast*>(input_row); + auto* v_w = reinterpret_cast*>(weight); + for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) { + vec_n_t src1 = v_in[idx]; + vec_n_t src2 = v_w[idx]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; j++) { + float x = static_cast(src1.val[j]); + float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j]; + out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] = + scaled_fp8_conversion(out_norm, scale_inv); + } } } @@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; + // For large num_tokens, use smaller blocks to increase SM concurrency. + const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_kernel_scalar_type", [&] { VLLM_DISPATCH_FP8_TYPES( out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { - vllm::rms_norm_static_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - input_stride, weight.data_ptr(), - scale.data_ptr(), epsilon, num_tokens, - hidden_size); + const int calculated_vec_size = + std::gcd(16 / sizeof(scalar_t), hidden_size); + const int block_size = + std::min(hidden_size / calculated_vec_size, max_block_size); + dim3 block(block_size); + VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] { + vllm::rms_norm_static_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + input_stride, weight.data_ptr(), + scale.data_ptr(), epsilon, num_tokens, + hidden_size); + }); }); }); } From d5edcb86781ea56f1eb0c9c5d7482a7cae00ec17 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Wed, 12 Nov 2025 02:18:02 +0800 Subject: [PATCH 071/183] [BugFix] Fix Siglip2Attention on XPU (#28448) Signed-off-by: Lin, Fanli --- vllm/model_executor/models/siglip2navit.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index c20bcd975ca3..29dd164ad37f 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -25,6 +25,7 @@ ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import current_platform from .vision import get_vit_attn_backend @@ -188,7 +189,7 @@ def apply_rotary_pos_emb( ) -> tuple[torch.Tensor, 
torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if is_flash_attn_backend: + if is_flash_attn_backend and not current_platform.is_xpu(): from flash_attn.layers.rotary import apply_rotary_emb apply_rotary_emb_func = apply_rotary_emb @@ -306,7 +307,13 @@ def forward( max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if self.is_flash_attn_backend: attn_output = self.flash_attn_varlen_func( - queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen + queries, + keys, + values, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, ).reshape(seq_length, -1) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. From 76e4dcf225e4de115bdc20b00a78d49bec767c09 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Tue, 11 Nov 2025 18:26:04 +0000 Subject: [PATCH 072/183] [Misc] Remove unused attention prefix prefill ops functions (#26971) Signed-off-by: Lukas Geiger --- vllm/attention/ops/prefix_prefill.py | 210 ------------------ .../compressed_tensors_moe.py | 3 - 2 files changed, 213 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index addf1d9dea73..f101d5c4a927 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -335,216 +335,6 @@ def _fwd_kernel( return -@triton.jit -def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0, - ) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load( - B_Loc - + cur_batch * stride_b_loc_b - + ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0, - ).to(tl.int64) - off_k = ( - bn[None, :] * stride_k_cache_bs - + cur_kv_head * stride_k_cache_h - + (offs_d[:, None] // x) * stride_k_cache_d - + ((start_n + offs_n[None, :]) % block_size) * 
stride_k_cache_bl - + (offs_d[:, None] % x) * stride_k_cache_x - ) - off_v = ( - bn[:, None] * stride_v_cache_bs - + cur_kv_head * stride_v_cache_h - + offs_d[None, :] * stride_v_cache_d - + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl - ) - k = tl.load( - K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0, - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where( - (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") - ) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = ( - offs_n[None, :] * stride_kbs - + cur_kv_head * stride_kh - + offs_d[:, None] * stride_kd - ) - off_v = ( - offs_n[:, None] * stride_vbs - + cur_kv_head * stride_vh - + offs_d[None, :] * stride_vd - ) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store( - out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len - ) - return - - @triton.jit def _fwd_kernel_alibi( Q, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 59567f2ca13c..6257a410e943 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -98,9 +98,6 @@ class GPTQMarlinState(Enum): class 
CompressedTensorsMoEMethod(FusedMoEMethodBase): - def __init_(self, moe: FusedMoEConfig): - super().__init__(moe) - @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 From 4228be7959e98e57d88501bd97aca7ef34ff562e Mon Sep 17 00:00:00 2001 From: Jialin Ouyang Date: Tue, 11 Nov 2025 10:28:47 -0800 Subject: [PATCH 073/183] [Perf] Use np.ndarray instead of list[list[int]] to reduce GC overhead (#28245) Signed-off-by: Jialin Ouyang --- tests/v1/engine/utils.py | 7 ++++--- vllm/v1/engine/logprobs.py | 7 ++++++- vllm/v1/outputs.py | 13 +++++++------ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 23684a2c55ce..3541ef89bfc1 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import TypeAlias +import numpy as np import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -369,9 +370,9 @@ def get_outputs(self) -> list[EngineCoreOutput]: self.generated_logprobs_raw[req_idx][token_idx] ) logprobs = LogprobsLists( - [logprobs_token_ids_], - [logprobs_], - [sampled_token_ranks_], + np.array([logprobs_token_ids_]), + np.array([logprobs_]), + np.array([sampled_token_ranks_]), ) else: logprobs = None diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 4c5955d7ee2e..b618d2347265 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -74,7 +74,12 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists - for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): + for rank_np, logprobs_np, token_ids_np in zip( + ranks_lst, logprobs_lst, token_ids_lst + ): + rank = rank_np.tolist() + logprobs = logprobs_np.tolist() + token_ids = token_ids_np.tolist() # Detokenize (non-incrementally). 
decoded_tokens = ( NONES diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b5cba96e1026..5f65e4ee0d1f 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, NamedTuple +import numpy as np import torch if TYPE_CHECKING: @@ -15,11 +16,11 @@ class LogprobsLists(NamedTuple): # [num_reqs x num_generated_tokens, max_num_logprobs + 1] - logprob_token_ids: list[list[int]] + logprob_token_ids: np.ndarray # [num_reqs x num_generated_tokens, max_num_logprobs + 1] - logprobs: list[list[float]] + logprobs: np.ndarray # [num_reqs x num_generated_tokens] - sampled_token_ranks: list[int] + sampled_token_ranks: np.ndarray # [num_reqs] # Used for slicing the logprobs in cases like speculative # decoding where the number of generated tokens may be @@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple): def tolists(self, cu_num_generated_tokens: list[int] | None = None): return LogprobsLists( - self.logprob_token_ids.tolist(), - self.logprobs.tolist(), - self.selected_token_ranks.tolist(), + self.logprob_token_ids.cpu().numpy(), + self.logprobs.cpu().numpy(), + self.selected_token_ranks.cpu().numpy(), cu_num_generated_tokens, ) From de120bc94f2e51633824093c626423ec8e7cb3a9 Mon Sep 17 00:00:00 2001 From: Canlin Guo <961750412@qq.com> Date: Wed, 12 Nov 2025 02:57:12 +0800 Subject: [PATCH 074/183] [V0 deprecation] Clean up num_prefill_tokens logic for V0 (#28203) Signed-off-by: gcanlin --- vllm/forward_context.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index ef37cf862c9f..44bc2a4cda31 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -5,7 +5,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple import torch @@ -185,18 +185,13 @@ class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] """ - Type AttentionMetadata for v0, Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one for each microbatch. 
Set dynamically for each forward pass """ - attn_metadata: Union[ - "AttentionMetadata", - dict[str, "AttentionMetadata"], - list[dict[str, "AttentionMetadata"]], - ] + attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass @@ -324,14 +319,7 @@ def set_forward_context( finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: - if hasattr(attn_metadata, "num_prefill_tokens"): - # for v0 attention backends - batchsize = ( - attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens - ) - else: - # for v1 attention backends - batchsize = num_tokens + batchsize = num_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch From 8c32c6e4b485f1cae1a1dc8a3f9895cf63f3e7af Mon Sep 17 00:00:00 2001 From: Jie Luo <65482183+Livinfly@users.noreply.github.com> Date: Wed, 12 Nov 2025 02:59:16 +0800 Subject: [PATCH 075/183] [Misc] fix typo in DCP comment (#28389) Signed-off-by: Livinfly --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b4cb5c200da3..19bd102cb1e3 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -2000,7 +2000,7 @@ def forward( decode_q, kv_cache, attn_metadata, layer ) - # recorect dcp attn_out with lse. + # correct dcp attn_out with lse. if self.dcp_world_size > 1: attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) From 9d1c47470430ba31c02946aa1fd01aadf6e18b91 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 12 Nov 2025 03:06:21 +0800 Subject: [PATCH 076/183] [LoRA][1/N]Remove LoRA extra vocab (#28382) Signed-off-by: Jee Jee Li --- vllm/model_executor/models/apertus.py | 30 +++------------- vllm/model_executor/models/arcee.py | 10 ++---- vllm/model_executor/models/arctic.py | 6 ++-- vllm/model_executor/models/aria.py | 8 ++--- vllm/model_executor/models/baichuan.py | 4 +-- vllm/model_executor/models/bailing_moe.py | 2 -- vllm/model_executor/models/bamba.py | 30 ++++------------ vllm/model_executor/models/chameleon.py | 8 ++--- vllm/model_executor/models/chatglm.py | 3 +- vllm/model_executor/models/commandr.py | 19 ++++------- vllm/model_executor/models/dbrx.py | 9 ++--- vllm/model_executor/models/exaone.py | 27 +++------------ vllm/model_executor/models/exaone4.py | 26 +++----------- vllm/model_executor/models/falcon_h1.py | 31 ++++------------- vllm/model_executor/models/gemma.py | 2 -- vllm/model_executor/models/gemma2.py | 3 +- vllm/model_executor/models/gemma3.py | 3 +- vllm/model_executor/models/gemma3n.py | 3 +- vllm/model_executor/models/glm4.py | 2 -- vllm/model_executor/models/gpt_bigcode.py | 20 +++-------- vllm/model_executor/models/granitemoe.py | 27 +++------------ .../model_executor/models/granitemoehybrid.py | 27 +++------------ .../model_executor/models/granitemoeshared.py | 28 +++------------ vllm/model_executor/models/grok1.py | 26 ++++---------- vllm/model_executor/models/hunyuan_v1.py | 21 ++++-------- vllm/model_executor/models/internlm2.py | 2 -- vllm/model_executor/models/jamba.py | 30 ++++------------ vllm/model_executor/models/kimi_vl.py | 10 ++---- vllm/model_executor/models/lfm2.py | 31 +++-------------- vllm/model_executor/models/lfm2_moe.py 
| 32 ++++------------- vllm/model_executor/models/llama_eagle3.py | 3 -- vllm/model_executor/models/longcat_flash.py | 3 +- vllm/model_executor/models/mamba.py | 29 ++++------------ vllm/model_executor/models/mamba2.py | 28 +++------------ vllm/model_executor/models/medusa.py | 12 ++----- vllm/model_executor/models/mimo.py | 2 -- vllm/model_executor/models/minicpm.py | 30 ++++------------ vllm/model_executor/models/minicpm_eagle.py | 29 ++++------------ vllm/model_executor/models/minimax_text_01.py | 11 ++---- vllm/model_executor/models/mlp_speculator.py | 1 - vllm/model_executor/models/molmo.py | 3 +- vllm/model_executor/models/nemotron.py | 30 ++++------------ vllm/model_executor/models/nemotron_h.py | 30 ++++------------ vllm/model_executor/models/nemotron_nas.py | 31 ++++------------- vllm/model_executor/models/olmo.py | 4 +-- vllm/model_executor/models/olmo2.py | 2 -- vllm/model_executor/models/ouro.py | 2 -- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/phi3v.py | 1 - vllm/model_executor/models/phi4mm.py | 14 ++------ vllm/model_executor/models/phimoe.py | 34 ++++--------------- vllm/model_executor/models/plamo2.py | 11 ++---- vllm/model_executor/models/qwen2.py | 2 -- vllm/model_executor/models/qwen2_rm.py | 2 -- vllm/model_executor/models/qwen3.py | 2 -- vllm/model_executor/models/qwen3_next.py | 30 ++++------------ vllm/model_executor/models/qwen3_next_mtp.py | 23 ++++--------- vllm/model_executor/models/qwen3_vl.py | 2 -- vllm/model_executor/models/seed_oss.py | 2 -- vllm/model_executor/models/solar.py | 30 ++++------------ vllm/model_executor/models/starcoder2.py | 12 ++----- vllm/model_executor/models/step3_text.py | 16 ++------- .../models/transformers/causal.py | 3 +- vllm/model_executor/models/whisper.py | 6 ++-- vllm/model_executor/models/zamba2.py | 28 +++------------ 65 files changed, 197 insertions(+), 754 deletions(-) diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 72e5ddcf1abe..233b8c79f299 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -49,7 +49,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -346,24 +345,18 @@ def __init__( config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -518,9 +511,7 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = self._init_model( vllm_config=vllm_config, @@ -529,20 +520,9 @@ def __init__( ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = 
config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -551,7 +531,7 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 08bf1a6aad75..f33970aff279 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -200,7 +199,6 @@ def __init__( self.quant_config = quant_config self.config = config self.vocab_size = config.vocab_size - self.org_vocab_size = config.vocab_size # Word embeddings (parallelized if using pipeline parallel) if get_pp_group().is_first_rank or ( @@ -209,7 +207,6 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -383,13 +380,10 @@ def __init__(self, *, vllm_config, prefix: str = "") -> None: if get_pp_group().is_last_rank: # Determine vocabulary size (including any LoRA extra tokens # for padded LM head) - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=vllm_config.quant_config, bias=getattr(config, "lm_head_bias", False), prefix=f"{prefix}.lm_head", @@ -399,7 +393,7 @@ def __init__(self, *, vllm_config, prefix: str = "") -> None: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: # Placeholder for lm_head on non-last ranks diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index bb505219ea17..ae3b96c83509 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -490,10 +490,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok - self.unpadded_vocab_size = config.vocab_size - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 222a42579054..fe37487d6ed8 100644 --- 
a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -547,18 +547,14 @@ def __init__( self.pad_token_id = ( self.config.pad_token_id if self.config.pad_token_id is not None else -1 ) - self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + self.vocab_size, config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale) def _parse_and_validate_image_input( self, **kwargs: object diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 39990b9fd683..dac012eb9f82 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -402,9 +402,9 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config self.model = BaiChuanModel( diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 1549c653482f..641bdb69c366 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -581,10 +581,8 @@ def __init__( config = vllm_config.model_config.hf_config.get_text_config() vllm_config.model_config.hf_config = config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings self.model = BailingMoeModel( diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index bc7dbb618f65..4a2b3da1c194 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -284,21 +283,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) def get_layer(prefix: str): @@ -478,7 +470,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config @@ -488,24 +480,14 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = BambaModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 54ff6991fa70..64f73e938bf6 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -963,9 +963,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = ChameleonModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -973,9 +973,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index bcbe82b78c3b..ccf7c9300166 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -433,10 +433,9 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + multimodal_config = vllm_config.model_config.multimodal_config self.config = config - self.lora_config = lora_config self.multimodal_config = multimodal_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 75459601f76b..6ae1dc356082 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -288,17 +288,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.quant_config = quant_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size ) @@ -424,17 +419,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config 
quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config # currently all existing command R models have `tie_word_embeddings` # enabled assert config.tie_word_embeddings - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.quant_config = quant_config self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale + config.vocab_size, scale=config.logit_scale ) self.model = CohereModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 22095d05848c..70999501f4c6 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -441,21 +440,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: raise ValueError("tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config - self.unpadded_vocab_size = config.vocab_size + self.transformer = DbrxModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") ) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 84fb52d13854..b9c7a520caff 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -48,7 +48,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -323,16 +322,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.wte = config.vocab_size if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank @@ -340,7 +334,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.wte = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -489,10 +482,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = 
vllm_config.lora_config self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.transformer = ExaoneModel( @@ -500,18 +492,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -520,7 +503,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index d5e4d9a1486f..6a5c888c095a 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -311,23 +310,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -476,10 +469,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Exaone4Model( @@ -487,18 +478,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -507,7 +489,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, 
"logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index ac5846cfd869..38838be29093 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -424,21 +423,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.embedding_multiplier = config.embedding_multiplier else: @@ -572,7 +565,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config @@ -584,21 +577,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.tie_word_embeddings = config.tie_word_embeddings - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_head_multiplier = config.lm_head_multiplier @@ -607,7 +590,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Used to track and store by the Mamba cache between steps. 
self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, + config.vocab_size, config.vocab_size, scale=config.lm_head_multiplier, ) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 46b111f4d939..caeee7c2e1ec 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -382,12 +382,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings - self.lora_config = lora_config self.quant_config = quant_config self.model = GemmaModel( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 1938efd4895e..efd01535fc3e 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -393,8 +393,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - del lora_config # Unused. + super().__init__() self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 80ec40f478c6..213f9f562f8a 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -524,8 +524,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - del lora_config # Unused. + super().__init__() self.config = config # currently all existing Gemma models have `tie_word_embeddings` enabled diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 547884f393eb..22d51ab76269 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -1114,8 +1114,7 @@ class Gemma3nForCausalLM(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config - del lora_config # Unused. 
+ super().__init__() self.config = config self.cache_config = vllm_config.cache_config diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index d7fd2b109d24..4172f16737c1 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -248,10 +248,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Glm4Model( diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index f2c8e2aeb822..99cdaabb98df 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -207,18 +207,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.wte = VocabParallelEmbedding( self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size ) @@ -290,10 +285,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.transformer = GPTBigCodeModel( @@ -305,15 +298,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = ParallelLMHead( self.transformer.vocab_size, self.transformer.embed_dim, - org_num_embeddings=self.config.vocab_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index e683f30805f3..c5b36c362ee3 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -50,7 +50,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -296,22 +295,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + 
self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.embedding_multiplier = config.embedding_multiplier @@ -518,26 +510,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = GraniteMoeModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -545,7 +527,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, scale=1 / self.config.logits_scaling, ) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index bac64eec8c55..3a98abed76fd 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -334,22 +333,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.embedding_multiplier = config.embedding_multiplier @@ -658,7 +650,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config self.config = config @@ -666,26 +658,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = GraniteMoeHybridModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using 
lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, + config.vocab_size, config.vocab_size, scale=1 / self.config.logits_scaling, ) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index e222109f2a94..e08e9f73ec87 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -159,23 +158,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config # Required by MixtralModel self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) self.embedding_multiplier = config.embedding_multiplier @@ -281,26 +273,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.model = GraniteMoeSharedModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -308,7 +290,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, + config.vocab_size, config.vocab_size, scale=1 / self.config.logits_scaling, ) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index d77a0bc2993a..0770e03b5356 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -305,18 +304,13 @@ def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.embedding_multiplier_scale = getattr( config, "embedding_multiplier_scale", DEFAULT_EMBEDDING_MULTIPLIER_SCALE ) @@ -324,7 +318,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) @@ -499,25 +492,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = Grok1Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -529,7 +515,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE ) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, self.output_multiplier_scale + config.vocab_size, scale=self.output_multiplier_scale ) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 8fa9776bd018..a05a00932c13 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -57,7 +57,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -606,7 +605,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + eplb_config = vllm_config.parallel_config.eplb_config enable_eplb = vllm_config.parallel_config.enable_eplb self.num_redundant_experts = eplb_config.num_redundant_experts @@ -614,20 +613,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( 
self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -937,12 +931,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -951,7 +942,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c5bbd5497a14..d856f5c79e33 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -330,11 +330,9 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - self.lora_config = lora_config self.model = model_type( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0cb993901fd3..70f52e3106f8 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -307,21 +306,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} @@ -492,7 +484,7 @@ class JambaForCausalLM( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config super().__init__() @@ -503,24 +495,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = JambaModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - 
else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index b79bdf8595ca..fa04f60b9c14 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -60,7 +60,6 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, ) from vllm.model_executor.model_loader.weight_utils import ( @@ -347,13 +346,10 @@ def __init__( vllm_config=sub_vllm_config, prefix=maybe_prefix(prefix, "language_model"), ) - self.unpadded_vocab_size = config.text_config.vocab_size if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.text_config.hidden_size, - org_num_embeddings=self.config.text_config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) else: @@ -362,9 +358,7 @@ def __init__( self.language_model.make_empty_intermediate_tensors ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) self.media_placeholder: int = self.config.media_placeholder_token_id def _parse_and_validate_image_input( diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 5684b9a89125..21d71887178e 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -316,16 +315,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size @@ -483,7 +476,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config + assert not cache_config.enable_prefix_caching, ( "Lfm2 currently does not support prefix caching" ) @@ -495,21 +488,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = self.config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( 
- self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -517,9 +498,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 02a490e9c7fd..b19116467105 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -423,20 +422,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size @@ -662,7 +656,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config + assert not cache_config.enable_prefix_caching, ( "Lfm2Moe currently does not support prefix caching" ) @@ -674,21 +668,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = self.config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -696,9 +678,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/llama_eagle3.py 
b/vllm/model_executor/models/llama_eagle3.py index da4bbda186b1..b8b9cc76d08d 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -252,8 +251,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = ParallelLMHead( self.config.draft_vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.draft_vocab_size, - padding_size=(DEFAULT_VOCAB_PADDING_SIZE), prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor( diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 5671347c00a2..b848ae6e822f 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -554,7 +554,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config config.intermediate_size = ( @@ -562,7 +561,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if hasattr(config, "ffn_hidden_size") else config.intermediate_size ) - self.lora_config = lora_config + self.quant_config = quant_config self.model = FlashModel( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f684203f6d35..02abe693e071 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -21,7 +21,6 @@ ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -110,18 +109,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): is_lora_enabled = bool(lora_config) self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embeddings = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.start_layer, self.end_layer, self.layers = make_layers( @@ -199,7 +192,7 @@ class MambaForCausalLM( ): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.scheduler_config = vllm_config.scheduler_config super().__init__() @@ -209,27 +202,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.backbone = MambaModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if config.tie_word_embeddings: self.lm_head = self.backbone.embeddings else: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - 
# compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 8ba8af66635b..d19480b064e0 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -20,7 +20,6 @@ ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -107,18 +106,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not is_lora_enabled self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embeddings = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.start_layer, self.end_layer, self.layers = make_layers( @@ -238,7 +231,7 @@ def get_mamba_state_shape_from_config( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config super().__init__() @@ -249,27 +242,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.backbone = Mamba2Model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 7e1d2bf14bb5..fd7fc2c73f16 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -70,14 +69,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size - self.unpadded_vocab_size = self.truncated_vocab_size if getattr(config, "original_lm_head", False): self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + 
self.truncated_vocab_size, config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_heads = [self.lm_head for _ in range(self.config.num_heads)] @@ -85,10 +81,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.lm_heads = nn.ModuleList( [ ParallelLMHead( - self.unpadded_vocab_size, + self.truncated_vocab_size, config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, f"lm_heads.{i}"), ) for i in range(self.config.num_heads) @@ -97,7 +91,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale + self.truncated_vocab_size, self.truncated_vocab_size, logit_scale ) # Token map is a idx to token mapping to reduce the vocab size for diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 726752a77e0d..666ac90c4429 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -151,10 +151,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 85d3542317a1..d9f0b477180e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -405,22 +404,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) @@ -588,13 +581,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.prefix = prefix self.vllm_config = vllm_config self.config = config - self.lora_config = lora_config + self.cache_config = cache_config self.quant_config = quant_config @@ -602,18 +595,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - unpadded_vocab_size =
config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -621,7 +605,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 463af9bbe139..6efc61e25ea1 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -151,18 +150,13 @@ def __init__( config = vllm_config.speculative_config.draft_model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + self.fc = torch.nn.Linear( self.config.hidden_size * 2, self.config.hidden_size, bias=False ) @@ -171,7 +165,6 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config, start_layer) @@ -321,12 +314,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.speculative_config.draft_model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.prefix = prefix self.vllm_config = vllm_config self.config = config - self.lora_config = lora_config + self.cache_config = cache_config self.quant_config = quant_config @@ -340,18 +332,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): start_layer=target_layer_num, ) - unpadded_vocab_size = config.vocab_size - if lora_config: - unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -359,7 +342,7 @@ def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index e262012dcd52..1409a309f3ae 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -669,16 +668,14 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config if not hasattr(config, "sliding_window"): config.sliding_window = None self.CONCAT_FFN = True - self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len self.model = MiniMaxText01Model( @@ -686,15 +683,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size + config.vocab_size, self.config.vocab_size ) else: diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 4901ac74fb28..48604d8e5103 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -123,7 +123,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: VocabParallelEmbedding( config.vocab_size, self.inner_dim, - org_num_embeddings=config.vocab_size, ) for _ in range(self.max_speculative_tokens) ] diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index dce94d181c4c..7a9e3d81b73a 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1404,10 +1404,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - lora_config = vllm_config.lora_config + self.config = config self.multimodal_config = multimodal_config - self.lora_config = lora_config vision_config = VisionBackboneConfig() self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 845798b18d1b..17e8e7f28258 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization import 
QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -319,24 +318,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) else: self.embed_tokens = PPMissingLayer() @@ -467,29 +460,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + assert isinstance(config, NemotronConfig) self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = NemotronModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -498,7 +482,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index fb58d01be7ba..8ef3eee173eb 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -50,7 +50,6 @@ ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -513,21 +512,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - lora_config = vllm_config.lora_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.has_moe = "E" in config.hybrid_override_pattern @@ -768,7 +760,7 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config self.quant_config = vllm_config.quant_config @@ -779,24 +771,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = NemotronHModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 17e009612df4..acd0d0c98234 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -250,25 +249,19 @@ def __init__( config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, ) else: @@ -437,29 +430,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config self.model = self._init_model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -468,7 +449,7 @@ def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 390a91d3425c..cb47f76a27ff 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -368,11 +368,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 7e39f6dff25e..2aa01adebc9f 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -408,11 +408,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=vllm_config.quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py index b8dad909c547..cc7947df50ae 100644 --- a/vllm/model_executor/models/ouro.py +++ b/vllm/model_executor/models/ouro.py @@ -462,10 +462,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = OuroModel( diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 34db124b6447..e76fb1904727 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -323,11 +323,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config # lm_head use bias, cannot share word embeddings assert not config.tie_word_embeddings - self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index b86fe67fb476..a7b28bd18cc7 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -591,7 +591,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "model.embed_tokens"), ) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index acad72b058fc..c2a3be16b610 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from 
vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, ) from vllm.model_executor.models.llama import LlamaModel @@ -1023,12 +1022,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): multimodal_config = vllm_config.model_config.multimodal_config assert multimodal_config, "multimodal_config is required" quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config - self.lora_config = lora_config # Tensor/Pipeline parallel not supported for now. assert get_pp_group().world_size == 1, "pipeline parallel is not supported" @@ -1055,23 +1052,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) def _parse_and_validate_audio_input( self, **kwargs: object diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index c7436cedeb22..97e553787790 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -458,22 +457,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + self.vocab_size = config.vocab_size + self.config = config self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -634,35 +626,23 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.quant_config = vllm_config.quant_config self.model = PhiMoEModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + 
config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size - ), quant_config=None, bias=True, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 6427ccfccc13..ece1c5ec23cf 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -751,12 +750,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( @@ -827,20 +824,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.vocab_size = self.config.vocab_size - self.unpadded_vocab_size = self.config.vocab_size - num_embeddings = ((self.vocab_size + 15) // 16) * 16 self.lm_head = ParallelLMHead( - num_embeddings, + self.vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=f"{prefix}.lm_head", ) if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.config.vocab_size + self.config.vocab_size, self.config.vocab_size ) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b26546647ce7..cdf32c6c5137 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -477,10 +477,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen2Model( diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e2ba0e262cf7..c5582218b852 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -43,10 +43,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen2Model( diff --git a/vllm/model_executor/models/qwen3.py
b/vllm/model_executor/models/qwen3.py index 563d3cc23d72..f689ff79d761 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -272,10 +272,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen3Model( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ddb8693c16e2..9cd342caacb0 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -59,7 +59,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -967,22 +966,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config: Qwen3NextConfig = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config - lora_config = vllm_config.lora_config + eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab + + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) def get_layer(prefix: str): @@ -1196,7 +1190,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vllm_config = vllm_config self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, ( "Qwen3Next currently does not support prefix caching" @@ -1209,23 +1203,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen3NextModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 271b76adcff7..9a552db029ee 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -48,17 +47,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + config: Qwen3NextConfig = model_config.hf_config self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) @@ -66,7 +60,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) self.fc = ColumnParallelLinear( @@ -252,17 +245,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen3NextMultiTokenPredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") ) - self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 97d4667d82e9..d880e6015e5d 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1136,10 +1136,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3ForCausalLM, self).__init__() config = vllm_config.model_config.hf_config.text_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index 641160295afb..04da19a440a1 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -440,10 +440,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.model = SeedOssModel( diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index f0dfce7bc7b6..5b8bf150edf6 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -277,24 +276,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config 
quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config self.quant_config = quant_config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( config.tie_word_embeddings and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) else: self.embed_tokens = PPMissingLayer() @@ -455,9 +448,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config + self.config = config - self.lora_config = lora_config + self.quant_config = quant_config self.model = SolarModel( @@ -465,18 +458,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -485,7 +469,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale + config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d147237808c2..4cdc90b1f5cb 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -319,22 +318,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size + if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=f"{prefix}.lm_head", ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index a2a1bfd30d8d..381b3f4932e5 100644 --- a/vllm/model_executor/models/step3_text.py +++ 
b/vllm/model_executor/models/step3_text.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -400,28 +399,19 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + self.config = config self.vllm_config = vllm_config self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/transformers/causal.py b/vllm/model_executor/models/transformers/causal.py index 7f7b15a5675a..42fd11117c73 100644 --- a/vllm/model_executor/models/transformers/causal.py +++ b/vllm/model_executor/models/transformers/causal.py @@ -42,7 +42,6 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): self.skip_prefixes.append("lm_head.") if self.pp_group.is_last_rank: - self.unpadded_vocab_size = self.text_config.vocab_size self.lm_head = ParallelLMHead( self.text_config.vocab_size, self.text_config.hidden_size, @@ -56,7 +55,7 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): logit_scale = getattr(self.text_config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale + self.text_config.vocab_size, scale=logit_scale ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ccfe1871ef07..502783b1fd93 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -890,7 +890,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.dtype = vllm_config.model_config.dtype self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) - self.unpadded_vocab_size = config.vocab_size + self.proj_out = ParallelLMHead( config.vocab_size, config.d_model, @@ -899,9 +899,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) def forward( self, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index bc1351600a2f..bf3107525bc5 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) @@ -692,19 +691,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: assert not is_lora_enabled self.config = config - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size + + self.vocab_size = config.vocab_size # Initialize token embeddings self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, ) # Map hybrid layer indices to block indices @@ -911,7 +904,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: (not supported by Mamba) """ config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config super().__init__() @@ -919,9 +912,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.vllm_config = vllm_config self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size # Initialize core model self.model = Zamba2Model( @@ -930,23 +920,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Initialize language modeling head self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, + config.vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config - else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Initialize logits processing and sampling - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(config.vocab_size) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. 
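The per-model diffs above all apply the same simplification: the LoRA-driven `unpadded_vocab_size` / `org_num_embeddings` / `padding_size` bookkeeping is dropped, and both the LM head and the logits processor are sized directly from the HF config's vocabulary. A minimal sketch of the resulting construction pattern follows; the helper name `build_lm_head` and the bare `config`/`prefix` arguments are illustrative only, while the `ParallelLMHead` and `LogitsProcessor` calls mirror the signatures used in the diffs.

```python
# Illustrative sketch only: mirrors the simplified pattern in the model diffs above.
# Assumes it is called from a model's __init__ after vLLM's distributed state is set up.
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead


def build_lm_head(config, prefix: str = ""):
    # No unpadded_vocab_size / org_num_embeddings / padding_size anymore:
    # the head is sized purely by the model's own vocab_size.
    lm_head = ParallelLMHead(
        config.vocab_size,
        config.hidden_size,
        prefix=f"{prefix}.lm_head" if prefix else "lm_head",
    )
    # Models that expose a logit_scale keep passing it via `scale=`.
    logit_scale = getattr(config, "logit_scale", 1.0)
    logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
    return lm_head, logits_processor
```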
From df4d3a44a83681feea723cc4c4ebe9085d29d58d Mon Sep 17 00:00:00 2001 From: Kyuyeun Kim <62023335+kyuyeunk@users.noreply.github.com> Date: Tue, 11 Nov 2025 11:16:47 -0800 Subject: [PATCH 077/183] [TPU] Rename path to tpu platform (#28452) Signed-off-by: Kyuyeun Kim --- vllm/platforms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index badf72de4a90..a45ca988200d 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -38,7 +38,7 @@ def tpu_platform_plugin() -> str | None: # Check for Pathways TPU proxy if envs.VLLM_TPU_USING_PATHWAYS: logger.debug("Confirmed TPU platform is available via Pathways proxy.") - return "tpu_inference.platforms.tpu_jax.TpuPlatform" + return "tpu_inference.platforms.tpu_platform.TpuPlatform" # Check for libtpu installation try: From d4902ba56d9b265698fb53f2d956117454945371 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 12 Nov 2025 06:28:07 +0800 Subject: [PATCH 078/183] [Misc] Cleanup Executor interface (#28441) Signed-off-by: wangxiyuan --- vllm/v1/executor/abstract.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 1e913876b763..db8303fcec50 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -294,12 +294,6 @@ def reset_mm_cache(self) -> None: """Reset the multi-modal cache in each worker.""" self.collective_rpc("reset_mm_cache") - def start_profile(self) -> None: - self.collective_rpc("start_profile") - - def stop_profile(self) -> None: - self.collective_rpc("stop_profile") - def sleep(self, level: int = 1): if self.is_sleeping: logger.warning("Executor is already sleeping.") From 28534b92b9f002e56d4e31d02ca59a070cdad468 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 17:53:59 -0500 Subject: [PATCH 079/183] Add Zurich vLLM Meetup (#28488) Signed-off-by: mgoin --- README.md | 1 + docs/community/meetups.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index b5e230e4b9b0..033e1035d891 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio *Latest News* 🔥 +- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link). - [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6). - [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing). 
diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 7ddd45799789..3fca4659e284 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Zurich Meetup](https://luma.com/0gls27kb), November 6th 2025. [[Slides]](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) [[Recording]](https://www.youtube.com/watch?v=6m6ZE6yVEDI) - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w), November 1st 2025. [[Slides]](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link) - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg), October 25th 2025. [[Slides]](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6) - [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing) From e5f599d4d1cfd34a5216cf0733d152ea42073f28 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 11 Nov 2025 18:16:12 -0500 Subject: [PATCH 080/183] [Bugfix] Disable shared expert overlap if Marlin MoE is used (#28410) Signed-off-by: mgoin --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++++ .../layers/fused_moe/shared_fused_moe.py | 10 +++++----- vllm/model_executor/layers/quantization/awq_marlin.py | 1 + .../compressed_tensors/compressed_tensors_moe.py | 1 + vllm/model_executor/layers/quantization/gptq_marlin.py | 1 + vllm/model_executor/layers/quantization/mxfp4.py | 1 + 6 files changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e198322ba7a8..615da58eeda2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -678,6 +678,10 @@ def use_flashinfer_cutlass_kernels(self): and self.moe_config.use_flashinfer_cutlass_kernels ) + @property + def use_marlin_kernels(self): + return getattr(self.quant_method, "use_marlin", False) + @property def use_dp_chunking(self) -> bool: return ( diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 06112ca51b6d..6ec8b33ed930 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -28,17 +28,17 @@ def __init__( super().__init__(**kwargs) self._shared_experts = shared_experts - # Disable shared expert overlap if we are using eplb, because of - # correctness issues, or if using flashinfer with DP, since there - # is nothing to be gained in this case. Disabling the overlap - # optimization also prevents the shared experts from being hidden - # from torch.compile. 
+        # Disable shared expert overlap if: + # - we are using eplb, because of correctness issues + # - we are using flashinfer with DP, since there is nothing to gain + # - we are using marlin kernels self.use_overlapped = ( use_overlapped and not ( # TODO(wentao): find the root cause and remove this condition self.enable_eplb or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) + or self.use_marlin_kernels ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 3e1f87b59a34..3f6ea68072b4 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -424,6 +424,7 @@ def __init__( if self.quant_config.weight_bits != 4: raise ValueError("AWQMoEMethod only supports 4bit now.") self.quant_type = scalar_types.uint4 + self.use_marlin = True def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 6257a410e943..f1050c15f79e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1342,6 +1342,7 @@ def __init__( f"{WNA16_SUPPORTED_BITS}", ) self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] + self.use_marlin = True def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 42a569e7770c..68a122fd46c6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -482,6 +482,7 @@ def __init__( self.quant_type = scalar_types.uint8b128 else: raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") + self.use_marlin = True def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8d7297a0a1b3..7940b359a150 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) + self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) From 412e153df557bbae541363ac4abde879a6d84488 Mon Sep 17 00:00:00 2001 From: Max Hu Date: Tue, 11 Nov 2025 18:32:20 -0500 Subject: [PATCH 081/183] [Feature] Allow configuring FlashInfer workspace size (#28269) Signed-off-by: Max Hu Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm/envs.py | 6 ++++++ vllm/v1/attention/backends/flashinfer.py | 6 +++--- vllm/v1/attention/backends/mla/common.py | 16 +++++++--------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 52a9671bc46e..5274c8ba1b24 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -159,6 +159,7 @@ VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency" + VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -1237,6 +1238,10 @@ def get_vllm_port() -> int | None: "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"] ), + # Control the workspace buffer size for the FlashInfer backend. + "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int( + os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024)) + ), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # the blockscale tensor of activations NVFP4 Quantization. @@ -1583,6 +1588,7 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", + "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", "VLLM_USE_CUDNN_PREFILL", "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "VLLM_USE_TRTLLM_ATTENTION", diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 07a0ab41a9e0..18bbc3cc3c12 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -16,6 +16,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor +from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -55,7 +56,6 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec -FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FP8_DTYPE = current_platform.fp8_dtype() @@ -70,7 +70,7 @@ def _get_trtllm_gen_workspace_buffer(): global trtllm_gen_workspace_buffer if trtllm_gen_workspace_buffer is None: trtllm_gen_workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" ) return trtllm_gen_workspace_buffer @@ -414,7 +414,7 @@ def __init__( def _get_workspace_buffer(self): if self._workspace_buffer is None: - buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE + buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE if vllm_is_batch_invariant(): buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT self._workspace_buffer = torch.zeros( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 19bd102cb1e3..467c01cd9d06 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -196,8 +196,8 @@ import torch from tqdm import tqdm -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import ( AttentionBackend, @@ -453,12 +453,6 @@ def use_trtllm_ragged_deepseek_prefill() -> bool: ) -# Currently 394MB, this can be tuned based on GEMM sizes used. 
-# Chosen to be the same as sglang: -# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37 -FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024 - - class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -590,7 +584,9 @@ def __init__( if self._use_fi_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, ) self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None @@ -602,7 +598,9 @@ def __init__( if self._use_trtllm_ragged_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, ) if self._use_cudnn_prefill: From d23539549a6db54ab152ce4e566c31f6891ddab5 Mon Sep 17 00:00:00 2001 From: Adrian Abeyta Date: Tue, 11 Nov 2025 18:34:58 -0600 Subject: [PATCH 082/183] Use FLASHINFER MLA backend when testing fp8_kv_scale_compile (#28491) Signed-off-by: adabeyta --- tests/compile/test_full_graph.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 71f90f6d8d3e..b4e5e56ac9fe 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -10,6 +10,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -184,13 +185,24 @@ def test_custom_compile_config( [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) @pytest.mark.parametrize( - "model", + "model, backend", [ - "Qwen/Qwen2-0.5B", # Standard attention model - "deepseek-ai/DeepSeek-V2-Lite", # MLA (Multi-head Latent Attention) model + ("Qwen/Qwen2-0.5B", None), # Standard attention model + ( + "deepseek-ai/DeepSeek-V2-Lite", + AttentionBackendEnum.FLASHINFER_MLA, + ), # MLA (Multi-head Latent Attention) model ], ) -def test_fp8_kv_scale_compile(compilation_mode: int, model: str): +def test_fp8_kv_scale_compile( + monkeypatch: pytest.MonkeyPatch, + compilation_mode: int, + model: str, + backend: AttentionBackendEnum | None, +): + if backend: + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + model_kwargs = { "quantization": "fp8", "kv_cache_dtype": "fp8_e4m3", From 1788aa1efb1f3cd8bf521885244aed3b89bed8a1 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Wed, 12 Nov 2025 01:41:54 +0100 Subject: [PATCH 083/183] [BugFix] Graceful handling of torch symm mem errors. 
(#27671) Signed-off-by: ilmarkov Co-authored-by: Michael Goin --- .../device_communicators/symm_mem.py | 22 +++++++++++++------ vllm/envs.py | 4 ++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index 74d6fb40c83b..eb1f173b1192 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -88,13 +88,21 @@ def __init__( self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ self.world_size ] - - self.buffer = torch_symm_mem.empty( - self.max_size // self.dtype.itemsize, - device=self.device, - dtype=self.dtype, - ) - handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + try: + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + except RuntimeError as e: + logger.warning_once( + "SymmMemCommunicator: symmetric memory initialization failed: %s " + "Communicator is not available. To suppress this warning set " + "VLLM_ALLREDUCE_USE_SYMM_MEM=0", + str(e), + ) + return if handle.multicast_ptr == 0: logger.warning( "SymmMemCommunicator: symmetric memory " diff --git a/vllm/envs.py b/vllm/envs.py index 5274c8ba1b24..46725efac70e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -201,7 +201,7 @@ VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False - VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False @@ -1389,7 +1389,7 @@ def get_vllm_port() -> int | None: ), # Whether to use pytorch symmetric memory for allreduce "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( - int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0")) + int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) ), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), From 48c879369f83ab1ab281a4bfe97f9a54790715d1 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 11 Nov 2025 16:46:18 -0800 Subject: [PATCH 084/183] [Frontend] Change CompilationMode to a proper Enum (#28165) Signed-off-by: Yanan Cao --- tests/compile/test_basic_correctness.py | 6 ++- tests/utils_/test_argparse_utils.py | 60 +++++++++++++++++++++++++ vllm/compilation/wrapper.py | 4 +- vllm/config/compilation.py | 51 ++++++++++++++------- vllm/config/vllm.py | 5 +-- vllm/entrypoints/llm.py | 5 ++- 6 files changed, 108 insertions(+), 23 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 132a838b8d44..3f6898607f6b 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -127,7 +127,9 @@ def test_compile_correctness( CompilationMode.VLLM_COMPILE, ]: for mode in [CompilationMode.NONE, comp_mode]: - all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"]) + all_args.append( + final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"] + ) # inductor will change the output, so we only compare if the output # is close, not exactly the same. 
@@ -146,7 +148,7 @@ def test_compile_correctness( CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE, ]: - all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"]) + all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"]) all_envs.append({}) all_envs.append({}) diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py index 51684edcc8a3..3310753d2b6d 100644 --- a/tests/utils_/test_argparse_utils.py +++ b/tests/utils_/test_argparse_utils.py @@ -8,6 +8,7 @@ import pytest import yaml from transformers import AutoTokenizer +from pydantic import ValidationError from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens @@ -376,6 +377,65 @@ def test_load_config_file(tmp_path): os.remove(str(config_file_path)) +def test_compilation_mode_string_values(parser): + """Test that -O.mode accepts both integer and string mode values.""" + args = parser.parse_args(["-O.mode", "0"]) + assert args.compilation_config == {"mode": 0} + + args = parser.parse_args(["-O3"]) + assert args.compilation_config == {"mode": 3} + + args = parser.parse_args(["-O.mode=NONE"]) + assert args.compilation_config == {"mode": "NONE"} + + args = parser.parse_args(["-O.mode", "STOCK_TORCH_COMPILE"]) + assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"} + + args = parser.parse_args(["-O.mode=DYNAMO_TRACE_ONCE"]) + assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"} + + args = parser.parse_args(["-O.mode", "VLLM_COMPILE"]) + assert args.compilation_config == {"mode": "VLLM_COMPILE"} + + args = parser.parse_args(["-O.mode=none"]) + assert args.compilation_config == {"mode": "none"} + + args = parser.parse_args(["-O.mode=vllm_compile"]) + assert args.compilation_config == {"mode": "vllm_compile"} + + +def test_compilation_config_mode_validator(): + """Test that CompilationConfig.mode field validator converts strings to integers.""" + from vllm.config.compilation import CompilationConfig, CompilationMode + + config = CompilationConfig(mode=0) + assert config.mode == CompilationMode.NONE + + config = CompilationConfig(mode=3) + assert config.mode == CompilationMode.VLLM_COMPILE + + config = CompilationConfig(mode="NONE") + assert config.mode == CompilationMode.NONE + + config = CompilationConfig(mode="STOCK_TORCH_COMPILE") + assert config.mode == CompilationMode.STOCK_TORCH_COMPILE + + config = CompilationConfig(mode="DYNAMO_TRACE_ONCE") + assert config.mode == CompilationMode.DYNAMO_TRACE_ONCE + + config = CompilationConfig(mode="VLLM_COMPILE") + assert config.mode == CompilationMode.VLLM_COMPILE + + config = CompilationConfig(mode="none") + assert config.mode == CompilationMode.NONE + + config = CompilationConfig(mode="vllm_compile") + assert config.mode == CompilationMode.VLLM_COMPILE + + with pytest.raises(ValidationError, match="Invalid compilation mode"): + CompilationConfig(mode="INVALID_MODE") + + def test_flat_product(): # Check regular itertools.product behavior result1 = list(flat_product([1, 2, 3], ["a", "b"])) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4b10c85209f6..4d26619bd128 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher: """ def __init__( - self, compiled_callable: Callable | None = None, compilation_mode: int = 0 + self, + compiled_callable: Callable | None = None, + compilation_mode: CompilationMode = CompilationMode.NONE, ): vllm_config = get_current_vllm_config() 
self.vllm_config = vllm_config diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 9c9557df4e73..e1d60ee84d89 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -28,7 +28,7 @@ logger = init_logger(__name__) -class CompilationMode: +class CompilationMode(enum.IntEnum): """The compilation approach used for torch.compile-based compilation of the model.""" @@ -115,7 +115,7 @@ class PassConfig: """The threshold of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a float in MB. - Unspecified will fallback to default values + Unspecified will fallback to default values which are compute capability and world size dependent. FI_ALLREDUCE_FUSION_MAX_SIZE_MB = { 90: { @@ -244,7 +244,7 @@ class CompilationConfig: Please use mode. Currently all levels are mapped to mode. """ # Top-level Compilation control - mode: int | None = None + mode: CompilationMode | None = None """The compilation approach used for torch.compile-based compilation of the model. @@ -377,23 +377,23 @@ class CompilationConfig: FULL mode: Capture full cudagraph for all batches. Can be good for small models or workloads with small prompts; not supported by many backends. Generally for performance FULL_AND_PIECEWISE is better. - + FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only. Mixed prefill-decode batches are run without cudagraphs. Can be good for decode instances in a P/D setup where prefill is not as important so we can save some memory. - + FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and piecewise cudagraph for prefill and mixed prefill-decode batches. This is the most performant mode for most models and is the default. Currently, the cudagraph mode is only used for the v1 engine. - Note that the cudagraph logic is generally orthogonal to the - compilation logic. While piecewise cudagraphs require piecewise + Note that the cudagraph logic is generally orthogonal to the + compilation logic. While piecewise cudagraphs require piecewise compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full cudagraphs are supported with and without compilation. - - Warning: This flag is new and subject to change in addition + + Warning: This flag is new and subject to change in addition more modes may be added. """ use_cudagraph: bool = True @@ -422,7 +422,7 @@ class CompilationConfig: cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. + internally managed buffer. Default is False. Note that this flag is only effective when cudagraph_mode is PIECEWISE. """ full_cuda_graph: bool | None = False @@ -451,7 +451,7 @@ class CompilationConfig: outside the partition functions. For a graph with N cudagraph-unsafe ops (e.g., Attention), there would be N+1 partitions. To mark an op as cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when - register the custom op. + register the custom op. This config supports both full cudagraph and piecewise cudagraph without compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper @@ -468,8 +468,8 @@ class CompilationConfig: max_cudagraph_capture_size: int | None = field(default=None) """The maximum cudagraph capture size. 
- - If cudagraph_capture_sizes is specified, this will be set to the largest + + If cudagraph_capture_sizes is specified, this will be set to the largest size in that list (or checked for consistency if specified). If cudagraph_capture_sizes is not specified, the list of sizes is generated automatically following the pattern: @@ -478,7 +478,7 @@ class CompilationConfig: range(256, max_cudagraph_capture_size + 1, 16)) If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2, - 512) by default. This voids OOM in tight memory scenarios with small + 512) by default. This voids OOM in tight memory scenarios with small max_num_seqs, and prevents capture of many large graphs (>512) that would greatly increase startup time with limited performance benefit. """ @@ -579,6 +579,27 @@ def __repr__(self) -> str: __str__ = __repr__ + @field_validator("mode", mode="before") + @classmethod + def validate_mode_before(cls, value: Any) -> Any: + """ + Enable parsing the `mode` field from string mode names. + Accepts both integers (0-3) and string names, like NONE, STOCK_TORCH_COMPILE, + DYNAMO_TRACE_ONCE, VLLM_COMPILE. + """ + if isinstance(value, str): + # Convert string mode name to integer value + mode_name = value.upper() + + if mode_name not in CompilationMode.__members__: + raise ValueError( + f"Invalid compilation mode: {value}. " + f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}" + ) + + return CompilationMode[mode_name] + return value + @field_validator("cudagraph_mode", mode="before") @classmethod def validate_cudagraph_mode_before(cls, value: Any) -> Any: @@ -904,7 +925,7 @@ def is_attention_compiled_piecewise(self) -> bool: return self.mode == CompilationMode.VLLM_COMPILE # Inductor partition case - return self.backend == "inductor" and self.mode > CompilationMode.NONE + return self.backend == "inductor" and self.mode != CompilationMode.NONE def custom_op_log_check(self): """ diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0fca967d9083..df9a1fd08af6 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -422,16 +422,13 @@ def __post_init__(self): self.compilation_config.mode = CompilationMode.VLLM_COMPILE else: self.compilation_config.mode = CompilationMode.NONE - else: - assert self.compilation_config.mode >= CompilationMode.NONE - assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE # If user does not set custom ops via none or all set it here based on # compilation mode and backend. 
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")): if ( self.compilation_config.backend == "inductor" - and self.compilation_config.mode > CompilationMode.NONE + and self.compilation_config.mode != CompilationMode.NONE ): self.compilation_config.custom_ops.append("none") else: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 22fe2ae9280a..62717a7eacdf 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -23,6 +23,7 @@ StructuredOutputsConfig, is_init_field, ) +from vllm.config.compilation import CompilationMode from vllm.config.model import ( ConvertOption, HfOverrides, @@ -259,7 +260,9 @@ def __init__( if compilation_config is not None: if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig(mode=compilation_config) + compilation_config_instance = CompilationConfig( + mode=CompilationMode(compilation_config) + ) elif isinstance(compilation_config, dict): compilation_config_instance = CompilationConfig( **{ From 3f770f4427cb926c24af540cc72d1b5901f7f702 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 12 Nov 2025 08:49:29 +0800 Subject: [PATCH 085/183] [Performance] Cache loaded custom logitsprocs to avoid overheads (#28462) Signed-off-by: Isotr0py --- vllm/v1/sample/logits_processor/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index eb537eae6c90..5992c4066c9c 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -5,7 +5,7 @@ import itertools from abc import abstractmethod from collections.abc import Sequence -from functools import partial +from functools import lru_cache, partial from typing import TYPE_CHECKING import torch @@ -216,11 +216,17 @@ def build_logitsprocs( ) +cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs) + + def validate_logits_processors_parameters( logits_processors: Sequence[str | type[LogitsProcessor]] | None, sampling_params: SamplingParams, ): - for logits_procs in _load_custom_logitsprocs(logits_processors): + logits_processors = ( + tuple(logits_processors) if logits_processors is not None else None + ) + for logits_procs in cached_load_custom_logitsprocs(logits_processors): logits_procs.validate_params(sampling_params) From e1710393c44cff20e481b632b86d157a9d694625 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 12 Nov 2025 09:22:16 +0800 Subject: [PATCH 086/183] [[V0 deprecation]]Remove VLLM_USE_V1 env (#28204) Signed-off-by: wangxiyuan --- .../scripts/hardware_ci/run-cpu-test.sh | 2 +- examples/offline_inference/mlpspeculator.py | 3 +- .../offline_inference/qwen2_5_omni/README.md | 2 - .../qwen2_5_omni/only_thinker.py | 7 +-- .../others/lmcache/cpu_offload_lmcache.py | 43 ++++++------------- tests/entrypoints/openai/test_orca_metrics.py | 3 -- vllm/envs.py | 13 ------ vllm/usage/usage_lib.py | 1 - 8 files changed, 15 insertions(+), 59 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7927aef19e4e..7e0f720feaa7 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -76,7 +76,7 @@ function cpu_tests() { # Run AWQ test # docker exec cpu-test-"$NUMA_NODE" bash -c " # set -e - # VLLM_USE_V1=0 pytest -x -s -v \ + # pytest -x -s -v \ # tests/quantization/test_ipex_quant.py" # Run multi-lora tests diff --git 
a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index d5b1b4ad29a9..6a533eb5c937 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -4,8 +4,7 @@ This file demonstrates the usage of text generation with an LLM model, comparing the performance with and without speculative decoding. -Note that still not support `v1`: -VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py +Note that this example is out of date and not supported in vLLM v1. """ import gc diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md index 16d44cbadbc9..d8fb50d7fe55 100644 --- a/examples/offline_inference/qwen2_5_omni/README.md +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -11,12 +11,10 @@ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ # Read vision and audio inputs from a single video file # NOTE: V1 engine does not support interleaved modalities yet. -VLLM_USE_V1=0 \ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ -q use_audio_in_video # Multiple audios -VLLM_USE_V1=0 \ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ -q multi_audios ``` diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index 6fbe1303f431..ed005e6a69b8 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -7,7 +7,6 @@ from typing import NamedTuple -import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult: ) asset = VideoAsset(name="baby_reading", num_frames=16) audio = asset.get_audio(sampling_rate=16000) - assert not envs.VLLM_USE_V1, ( - "V1 does not support use_audio_in_video. " - "Please launch this example with " - "`VLLM_USE_V1=0`." - ) + return QueryResult( inputs={ "prompt": prompt, diff --git a/examples/others/lmcache/cpu_offload_lmcache.py b/examples/others/lmcache/cpu_offload_lmcache.py index e10ee4e2a9a9..53036b3eb0ff 100644 --- a/examples/others/lmcache/cpu_offload_lmcache.py +++ b/examples/others/lmcache/cpu_offload_lmcache.py @@ -37,7 +37,7 @@ from vllm.engine.arg_utils import EngineArgs -def setup_environment_variables(vllm_version: str): +def setup_environment_variables(): # LMCache-related environment variables # Use experimental features in LMCache os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" @@ -47,12 +47,10 @@ def setup_environment_variables(vllm_version: str): os.environ["LMCACHE_LOCAL_CPU"] = "True" # Set local CPU memory limit to 5.0 GB os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" - if vllm_version == "v0": - os.environ["VLLM_USE_V1"] = "0" @contextlib.contextmanager -def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str): +def build_llm_with_lmcache(lmcache_connector: str, model: str): ktc = KVTransferConfig( kv_connector=lmcache_connector, kv_role="kv_both", @@ -60,21 +58,12 @@ def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392). 
- if vllm_version == "v0": - llm_args = EngineArgs( - model=model, - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - enable_chunked_prefill=True, # Only in v0 - ) - else: - llm_args = EngineArgs( - model=model, - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8, - ) + llm_args = EngineArgs( + model=model, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + ) llm = LLM(**asdict(llm_args)) try: @@ -116,18 +105,10 @@ def parse_args(): def main(): - args = parse_args() - - if args.version == "v0": - lmcache_connector = "LMCacheConnector" - model = "mistralai/Mistral-7B-Instruct-v0.2" - else: - lmcache_connector = "LMCacheConnectorV1" - model = "meta-llama/Meta-Llama-3.1-8B-Instruct" - - setup_environment_variables(args.version) - - with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm: + lmcache_connector = "LMCacheConnectorV1" + model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + setup_environment_variables() + with build_llm_with_lmcache(lmcache_connector, model) as llm: # This example script runs two requests with a shared prefix. # Define the shared prompt and specific prompts shared_prompt = "Hello, how are you?" * 1000 diff --git a/tests/entrypoints/openai/test_orca_metrics.py b/tests/entrypoints/openai/test_orca_metrics.py index d32cfde07c21..1ed44a33bf81 100644 --- a/tests/entrypoints/openai/test_orca_metrics.py +++ b/tests/entrypoints/openai/test_orca_metrics.py @@ -22,9 +22,6 @@ def monkeypatch_module(): @pytest.fixture(scope="module", params=[True]) def server(request, monkeypatch_module): - use_v1 = request.param - monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - args = [ "--dtype", "bfloat16", diff --git a/vllm/envs.py b/vllm/envs.py index 46725efac70e..2aa6afcabf28 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -100,7 +100,6 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLE_PYNCCL: bool = False - VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True @@ -884,8 +883,6 @@ def get_vllm_port() -> int | None: "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") ), - # If set, use the V1 code path. - "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. "VLLM_ROCM_USE_AITER": lambda: ( @@ -1538,16 +1535,6 @@ def is_set(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -def set_vllm_use_v1(use_v1: bool): - if is_set("VLLM_USE_V1"): - raise ValueError( - "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " - "explicitly by the user. Please raise this as a Github " - "Issue and explicitly set VLLM_USE_V1=0 or 1." 
- ) - os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" - - def compute_hash() -> str: """ WARNING: Whenever a new key is added to this environment diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index c8bff8b7c80b..4eddaf56d81a 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -42,7 +42,6 @@ "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_PP_LAYER_PARTITION", "VLLM_USE_TRITON_AWQ", - "VLLM_USE_V1", "VLLM_ENABLE_V1_MULTIPROCESSING", ] From 7f829be7d3d734020606fcca520f3c500581beb8 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 12 Nov 2025 09:43:06 +0800 Subject: [PATCH 087/183] [CPU] Refactor CPU attention backend (#27954) Signed-off-by: jiang1.li --- .buildkite/release-pipeline.yaml | 2 +- .../scripts/hardware_ci/run-cpu-test.sh | 3 +- cmake/cpu_extension.cmake | 28 +- csrc/cpu/attention.cpp | 798 ------- csrc/cpu/cache.cpp | 214 -- csrc/cpu/cpu_attn.cpp | 249 +++ csrc/cpu/cpu_attn_amx.hpp | 511 +++++ csrc/cpu/cpu_attn_impl.hpp | 1977 +++++++++++++++++ csrc/cpu/cpu_attn_macros.h | 63 + csrc/cpu/cpu_attn_vec.hpp | 248 +++ csrc/cpu/cpu_attn_vec16.hpp | 171 ++ csrc/cpu/cpu_types_x86.hpp | 50 +- csrc/cpu/dnnl_helper.cpp | 18 +- csrc/cpu/dnnl_helper.h | 24 - csrc/cpu/scratchpad_manager.cpp | 23 + csrc/cpu/scratchpad_manager.h | 31 + csrc/cpu/shm.cpp | 2 +- csrc/cpu/torch_bindings.cpp | 105 +- docker/Dockerfile.cpu | 4 + docs/getting_started/installation/cpu.md | 2 + .../attention/test_attention_selector.py | 6 +- tests/kernels/attention/test_cpu_attn.py | 575 +++++ tests/kernels/test_onednn.py | 1 - .../models/language/generation/test_common.py | 17 +- .../models/language/pooling/test_embedding.py | 3 +- tests/models/registry.py | 4 +- vllm/_custom_ops.py | 82 + vllm/attention/backends/registry.py | 3 +- vllm/engine/arg_utils.py | 3 - vllm/platforms/cpu.py | 37 +- vllm/utils/__init__.py | 1 - vllm/v1/attention/backends/cpu_attn.py | 981 +++----- vllm/v1/attention/backends/utils.py | 2 +- vllm/v1/worker/cpu_model_runner.py | 14 +- 34 files changed, 4352 insertions(+), 1900 deletions(-) delete mode 100644 csrc/cpu/attention.cpp delete mode 100644 csrc/cpu/cache.cpp create mode 100644 csrc/cpu/cpu_attn.cpp create mode 100644 csrc/cpu/cpu_attn_amx.hpp create mode 100644 csrc/cpu/cpu_attn_impl.hpp create mode 100644 csrc/cpu/cpu_attn_macros.h create mode 100644 csrc/cpu/cpu_attn_vec.hpp create mode 100644 csrc/cpu/cpu_attn_vec16.hpp create mode 100644 csrc/cpu/scratchpad_manager.cpp create mode 100644 csrc/cpu/scratchpad_manager.h create mode 100644 tests/kernels/attention/test_cpu_attn.py diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 12f730738b8a..38c400ba1faf 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -132,7 +132,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." 
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7e0f720feaa7..7479c43977d7 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -49,6 +49,7 @@ function cpu_tests() { # Run kernel tests docker exec cpu-test-"$NUMA_NODE" bash -c " set -e + pytest -x -v -s tests/kernels/attention/test_cpu_attn.py pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test @@ -116,4 +117,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 2.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index dbda19fbcbf2..51447cde0b29 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -15,6 +15,7 @@ endif() # set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) +set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16}) include_directories("${CMAKE_SOURCE_DIR}/csrc") @@ -140,6 +141,22 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) set(ENABLE_AVX512VNNI OFF) message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") endif() + + find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND) + if (AMXBF16_FOUND OR ENABLE_AMXBF16) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile") + set(ENABLE_AMXBF16 ON) + add_compile_definitions(-DCPU_CAPABILITY_AMXBF16) + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3") + endif() + else() + set(ENABLE_AMXBF16 OFF) + message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." 
" If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.") + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -275,7 +292,10 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) + set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size FetchContent_MakeAvailable(oneDNN) + set(CMAKE_BUILD_TYPE ${VLLM_BUILD_TYPE}) add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") target_include_directories( dnnl_ext @@ -305,14 +325,14 @@ endif() # set(VLLM_EXT_SRC "csrc/cpu/activation.cpp" - "csrc/cpu/attention.cpp" - "csrc/cpu/cache.cpp" "csrc/cpu/utils.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp" - "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/scratchpad_manager.cpp" + "csrc/cpu/torch_bindings.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp deleted file mode 100644 index 82862fea7f2b..000000000000 --- a/csrc/cpu/attention.cpp +++ /dev/null @@ -1,798 +0,0 @@ -#include "cpu_types.hpp" - -namespace { - -template -struct KernelVecType { - using q_load_vec_type = void; - using q_vec_type = void; - using k_load_vec_type = void; - using k_vec_type = void; - using qk_acc_vec_type = void; - using v_load_vec_type = void; -}; - -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::FP32Vec4; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -}; - -template <> -struct KernelVecType { -#if defined(__powerpc64__) || defined(__s390x__) - // Power and s390x architecture-specific vector types - using q_load_vec_type = vec_op::FP32Vec8; - using k_load_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP32Vec16; -#else - // Fallback for other architectures, including x86 - using q_load_vec_type = vec_op::FP16Vec8; - using k_load_vec_type = vec_op::FP16Vec16; - using v_load_vec_type = vec_op::FP16Vec16; -#endif - using q_vec_type = vec_op::FP32Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; -}; - -#ifdef __AVX512BF16__ -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::BF16Vec32; - using k_load_vec_type = vec_op::BF16Vec32; - using k_vec_type = vec_op::BF16Vec32; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; -#else - #ifdef __aarch64__ - #ifndef ARM_BF16_SUPPORT - // pass - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif - #else -template <> -struct KernelVecType { - using q_load_vec_type = vec_op::BF16Vec8; - using q_vec_type = vec_op::FP32Vec16; - using k_load_vec_type = vec_op::BF16Vec16; - using k_vec_type = vec_op::FP32Vec16; - using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::BF16Vec16; -}; - #endif -#endif - -template -FORCE_INLINE std::pair reduceSoftmax(T* data, const int 
size, - const int capacity) { - T max = data[0]; - for (int i = 1; i < size; ++i) { - max = max >= data[i] ? max : data[i]; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, - const int capacity, - const float alibi_slope, - const int start_index, - const int seq_len) { - data[0] += alibi_slope * (start_index - seq_len + 1); - T max = data[0]; - for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); - data[i] = qk; - max = max >= qk ? max : qk; - } - - T sum = 0; - for (int i = 0; i < size; ++i) { - data[i] = std::exp(data[i] - max); - sum += data[i]; - } - - int i = 0; - for (; i < size; ++i) { - data[i] /= sum; - } - - for (; i < capacity; ++i) { - data[i] = 0; - } - - return {max, sum}; -} - -template -FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, - const int size) { - T max = max_data[0]; - for (int i = 1; i < size; ++i) { - max = max >= max_data[i] ? max : max_data[i]; - } - - T rescaled_sum = 0; - for (int i = 0; i < size; ++i) { - T rescale_factor = std::exp(max_data[i] - max); - rescaled_sum += rescale_factor * sum_data[i]; - sum_data[i] *= rescale_factor; - } - for (int i = 0; i < size; ++i) { - sum_data[i] /= rescaled_sum + 1e-8; - } -} - -template -struct reduceQKBlockKernel { - using q_load_vec_type = typename KernelVecType::q_load_vec_type; - using q_vec_type = typename KernelVecType::q_vec_type; - using k_load_vec_type = typename KernelVecType::k_load_vec_type; - using k_vec_type = typename KernelVecType::k_vec_type; - using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; - - constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; - constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; - constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; - - static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); - static_assert(k_load_vec_type::get_elem_num() % x == 0); - static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - - FORCE_INLINE static void call(const scalar_t* __restrict__ q, - const scalar_t* __restrict__ k_block, - float* __restrict__ logits, float scale, - const int token_num) { - const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; - - qk_acc_vec_type group_accums[MAX_GROUP_NUM]; - if (token_num == BLOCK_SIZE) { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - - vec_op::unroll_loop( - [k_block, &q_group_vec, &group_accums](int token_group_idx) { - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } else { - for (int q_offset = 0; q_offset < HEAD_SIZE; - q_offset += x, k_block += x * BLOCK_SIZE) { - q_load_vec_type q_load_group_vec(q + q_offset); - q_vec_type q_group_vec(q_load_group_vec); - for (int token_group_start = 0; token_group_start < group_num; - token_group_start += UNROLL_GROUP_NUM) { - vec_op::unroll_loop( - [token_group_start, k_block, 
&q_group_vec, - &group_accums](int token_group_idx) { - token_group_idx += token_group_start; - k_load_vec_type k_load_group_vec(k_block + token_group_idx * x * - TOKEN_PER_GROUP); - k_vec_type k_group_vec(k_load_group_vec); - vec_op::fma(group_accums[token_group_idx], q_group_vec, - k_group_vec); - vec_op::prefetch(k_block + x * BLOCK_SIZE + - token_group_idx * x * TOKEN_PER_GROUP); - }); - } - } - } - - for (int token_group_idx = 0; token_group_idx < group_num; - ++token_group_idx) { - vec_op::unroll_loop( - [&group_accums, logits, scale, token_group_idx](int token_idx) { - float dot_v = - group_accums[token_group_idx] - .template reduce_sub_sum(token_idx); - logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = - dot_v * scale; - }); - } - } -}; - -template -FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, - acc_t&& acc) { - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); - static_assert(BLOCK_SIZE == ELEM_NUM); - vec_op::FP32Vec16 prob_vec(prob); - - vec_op::unroll_loop([&](int head_elem_idx) { - v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx); - vec_op::FP32Vec16 fp32_v_vec(v_vec); - acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; - }); -} -}; // namespace - -// Paged attention v1 -namespace { -template -struct paged_attention_v1_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - - int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); - - const int parallel_work_item_num = omp_get_max_threads(); - - size_t logits_bytes = - parallel_work_item_num * max_seq_len_padded * sizeof(float); - float* logits = (float*)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. 
- // [parallel_work_item_num, max_seq_len_padded] - -#pragma omp parallel for collapse(2) schedule(dynamic, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int seq_len = seq_lens[seq_idx]; - const int* seq_block_table = - block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; - float* __restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_seq_len_padded; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - thread_block_logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); - } - - // Compute softmax - if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, seq_len, - block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - seq_len); - } else { - reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - std::free(logits); - } -}; - -#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v1_impl::call( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ - num_heads); - -template -void paged_attention_v1_impl_launcher( - torch::Tensor& out, torch::Tensor& 
query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes); - -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v1( - torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) - CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) - }); -} - -// Paged attention v2 -namespace { -template -struct paged_attention_v2_impl { - static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // 
[num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, - // max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads, const int max_num_partitions) { - constexpr int x = 16 / sizeof(scalar_t); - const int num_queries_per_kv = num_heads / num_kv_heads; - - static_assert(BLOCK_SIZE == 16); - static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0); - static_assert(PARTITION_SIZE % BLOCK_SIZE == 0); - -#pragma omp parallel for collapse(3) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int partition_idx = 0; partition_idx < max_num_partitions; - ++partition_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int start_token_idx = partition_idx * PARTITION_SIZE; - - if (start_token_idx >= seq_len) continue; - - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - const bool no_reduce = (partition_num == 1); - const int token_num = - (std::min(seq_len, start_token_idx + PARTITION_SIZE) - - start_token_idx); - const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; - const int last_block_token_num = - token_num - (block_num - 1) * BLOCK_SIZE; - const int* seq_block_table = block_tables + - max_num_blocks_per_seq * seq_idx + - start_token_idx / BLOCK_SIZE; - const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t* __restrict__ q_vec_ptr = - q + seq_idx * q_stride + head_idx * HEAD_SIZE; - - float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; - - // Compute logits - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = - k_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride; - float* __restrict__ head_block_logits = - logits + block_idx * BLOCK_SIZE; - - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? 
last_block_token_num : BLOCK_SIZE); - } - - std::pair max_and_sum; - if (alibi_slopes) { - max_and_sum = reduceSoftmaxAlibi( - logits, token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, seq_len); - } else { - max_and_sum = - reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); - } - - auto&& [max_logit, exp_sum] = max_and_sum; - - scalar_t* __restrict__ output_buffer = nullptr; - if (!no_reduce) { - auto idx = seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - max_logits[idx] = max_logit; - exp_sums[idx] = exp_sum; - output_buffer = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - output_buffer = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - } - - // Compute value - constexpr int head_elem_num_per_partition = 16; - constexpr int head_partition_num = - HEAD_SIZE / head_elem_num_per_partition; - for (int head_part_idx = 0; head_part_idx < head_partition_num; - ++head_part_idx) { - vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t* __restrict__ out_ptr = - output_buffer + head_part_idx * head_elem_num_per_partition; - for (int block_idx = 0; block_idx < block_num; ++block_idx) { - const int64_t physical_block_idx = seq_block_table[block_idx]; - const float* __restrict__ prob_vec_ptr = - logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = - v_cache + physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - reduceValueBlock( - prob_vec_ptr, v_block_cache_ptr, accums); - - if (block_idx != block_num - 1) { - const int64_t next_physical_block_idx = - seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = - v_cache + next_physical_block_idx * kv_block_stride + - kv_head_idx * kv_head_stride + - BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; - vec_op::unroll_loop( - [&](int head_elem_idx) { - if (head_elem_idx % 2 == 0) { - vec_op::prefetch(next_v_block_cache_ptr + - BLOCK_SIZE * head_elem_idx); - } - }); - } - } - - vec_op::unroll_loop( - [&](int head_elem_idx) { - float value = accums[head_elem_idx].reduce_sum(); - vec_op::storeFP32(value, out_ptr + head_elem_idx); - }); - } - } - } - } - - // Rescale partition softmax and store the factors to exp_sums -#pragma omp parallel for collapse(2) schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - reducePartitionSoftmax( - max_logits + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions, - partition_num); - } - } - - // Reduce values - using v_load_vec_type = typename KernelVecType::v_load_vec_type; - static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); - constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some - // HEAD_SIZE didn't align with 64 bytes - static_assert(HEAD_SIZE % head_elem_num_per_group == 0); - constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float* __restrict__ rescale_factors = exp_sums; -#pragma omp parallel for collapse(3) 
schedule(static, 1) - for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int seq_len = seq_lens[seq_idx]; - const int partition_num = - (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - - if (partition_num == 1) continue; - - const float* __restrict__ seq_head_rescale_factors = - rescale_factors + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - const scalar_t* __restrict__ seq_head_tmp_out = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + - group_idx * head_elem_num_per_group; - scalar_t* __restrict__ seq_head_output = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + - group_idx * head_elem_num_per_group; - - vec_op::FP32Vec16 acc; - for (int i = 0; i < partition_num; ++i) { - vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]); - v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE); - vec_op::FP32Vec16 fp32_value(value); - acc = acc + fp32_value * rescale_factor; - } - v_load_vec_type cast_acc(acc); - cast_acc.save(seq_head_output); - } - } - } - } -}; - -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ - max_num_partitions); - -template -void paged_attention_v2_impl_launcher( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const std::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - int max_num_partitions = exp_sums.size(-1); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = - alibi_slopes - ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* seq_lens_ptr = seq_lens.data_ptr(); - - switch (head_size) { - case 32: - LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); - break; - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 192: - LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ - alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } -} // namespace - -void paged_attention_v2( - torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, - torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, - const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) - CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) - }); -} \ No newline at end of file diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp deleted file mode 100644 index 69f6d06e3c96..000000000000 --- a/csrc/cpu/cache.cpp +++ /dev/null @@ -1,214 +0,0 @@ -#include -#include - -#include "cpu_types.hpp" - -#if defined(__x86_64__) - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2 -#else - #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES -#endif - -namespace { -template -void copy_blocks_cpu_impl(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, - const int layer_num) { - const size_t pair_num = mapping_pairs.size(0); - const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; -#pragma omp parallel for collapse(2) - for (int layer = 0; layer < 
layer_num; ++layer) { - for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = - element_num_per_block * mapping_pairs[pair][0].item(); - int64_t target_offset = - element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t* source_ptr = key_cache_ptr + source_offset; - scalar_t* target_ptr = key_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - - scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); - source_ptr = value_cache_ptr + source_offset; - target_ptr = value_cache_ptr + target_offset; - std::memcpy(target_ptr, source_ptr, block_bytes); - } - } -} - -template -void reshape_and_cache_cpu_impl( - const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, - scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const int64_t* __restrict__ slot_mapping, const int num_tokens, - const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x) { - const int block_elem_num = num_heads * head_size * block_size; - -#pragma omp parallel for collapse(2) - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - if (slot_idx >= 0) { - int src_key_head_idx = token_idx * key_stride + head_idx * head_size; - int src_value_head_idx = - token_idx * value_stride + head_idx * head_size; - const scalar_t* src_key_head_ptr = key + src_key_head_idx; - const scalar_t* src_value_head_ptr = value + src_value_head_idx; - const int64_t block_index = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - scalar_t* target_key_head_ptr = key_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - scalar_t* target_value_head_ptr = value_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - - for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { - const int64_t target_offset = - src_key_idx * block_size + block_offset * x; - for (int i = 0; i < x; ++i) { - target_key_head_ptr[target_offset + i] = - src_key_head_ptr[src_key_idx + i]; - } - } - - for (int src_value_idx = 0; src_value_idx < head_size; - ++src_value_idx) { - const int64_t target_offset = - src_value_idx * block_size + block_offset; - target_value_head_ptr[target_offset] = - src_value_head_ptr[src_value_idx]; - } - } - } - } -} -}; // namespace - -template -void concat_and_cache_mla_cpu_impl( - const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] - scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int num_tokens, // - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size // -) { -#pragma omp parallel for - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - const int64_t slot_idx = slot_mapping[token_idx]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - continue; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - auto copy = [&](const scalar_t* __restrict__ src, - scalar_t* __restrict__ dst, int src_stride, int dst_stride, - 
int size, int offset) { - for (int i = 0; i < size; i++) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - dst[dst_idx] = src[src_idx]; - } - }; - - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); - } -} - -// Note: the key_caches and value_caches vectors are constant but -// not the Tensors they contain. The vectors need to be const refs -// in order to satisfy pytorch's C++ operator registration code. -void copy_blocks(std::vector const& key_caches, - std::vector const& value_caches, - const torch::Tensor& block_mapping) { - unsigned num_layers = key_caches.size(); - TORCH_CHECK(num_layers == value_caches.size()); - if (num_layers == 0) { - return; - } - - const int element_num_per_block = key_caches[0][0].numel(); - DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, - element_num_per_block, num_layers); - CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) - }); -} - -void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(3); - int x = key_cache.size(4); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - - DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) - reshape_and_cache_cpu_impl( - key.data_ptr(), value.data_ptr(), - key_cache.data_ptr(), value_cache.data_ptr(), - slot_mapping.data_ptr(), num_tokens, key_stride, value_stride, - num_heads, head_size, block_size, x); - CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) - }); -} - -void concat_and_cache_mla( - torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_tokens, pe_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - TORCH_CHECK(kv_cache_dtype != "fp8"); - - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - - VLLM_DISPATCH_FLOATING_TYPES( - kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl) - concat_and_cache_mla_cpu_impl( - kv_c.data_ptr(), k_pe.data_ptr(), - kv_cache.data_ptr(), slot_mapping.data_ptr(), - num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride, - kv_lora_rank, pe_dim, block_size); - CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl) - }); -} - -void swap_blocks(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping) { - TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") -} diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp new file mode 100644 index 000000000000..50f17c758c14 --- /dev/null +++ 
b/csrc/cpu/cpu_attn.cpp @@ -0,0 +1,249 @@ +#include "cpu_attn_vec.hpp" +#include "cpu_attn_vec16.hpp" + +#ifdef CPU_CAPABILITY_AMXBF16 + #include "cpu_attn_amx.hpp" + #define AMX_DISPATCH(...) \ + case cpu_attention::ISA::AMX: { \ + using attn_impl = cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } +#else + #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: +#endif + +#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ + case HEAD_DIM: { \ + constexpr size_t head_dim = HEAD_DIM; \ + return __VA_ARGS__(); \ + } + +#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \ + [&] { \ + switch (HEAD_DIM) { \ + CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \ + std::to_string(HEAD_DIM)); \ + } \ + } \ + }() + +#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \ + [&] { \ + switch (ISA_TYPE) { \ + AMX_DISPATCH(__VA_ARGS__) \ + case cpu_attention::ISA::VEC: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + case cpu_attention::ISA::VEC16: { \ + using attn_impl = \ + cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } \ + default: { \ + TORCH_CHECK(false, "Invalid CPU attention ISA type."); \ + } \ + } \ + }() + +torch::Tensor get_scheduler_metadata( + const int64_t num_req, const int64_t num_heads_q, + const int64_t num_heads_kv, const int64_t head_dim, + const torch::Tensor& seq_lens, at::ScalarType dtype, + const torch::Tensor& query_start_loc, const bool casual, + const int64_t window_size, const std::string& isa_hint, + const bool enable_kv_split) { + cpu_attention::ISA isa; + if (isa_hint == "amx") { + isa = cpu_attention::ISA::AMX; + } else if (isa_hint == "vec") { + isa = cpu_attention::ISA::VEC; + } else if (isa_hint == "vec16") { + isa = cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); + } + + cpu_attention::AttentionScheduler::ScheduleInput input; + input.num_reqs = num_req; + input.num_heads_q = num_heads_q; + input.num_heads_kv = num_heads_kv; + input.head_dim = head_dim; + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + if (window_size != -1) { + input.left_sliding_window_size = window_size - 1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = window_size - 1; + } + } else { + input.left_sliding_window_size = -1; + if (casual) { + input.right_sliding_window_size = 0; + } else { + input.right_sliding_window_size = -1; + } + } + input.casual = casual; + input.isa = isa; + input.enable_kv_split = enable_kv_split; + TORCH_CHECK(casual, "Only supports casual mask for now."); + + VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa, [&]() { + input.elem_size = sizeof(scalar_t); + input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t); + input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t); + input.output_buffer_elem_size = + sizeof(attn_impl::partial_output_buffer_t); + input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration; + input.kv_block_alignment = attn_impl::BlockSizeAlignment; 
+ }); + }); + }); + + cpu_attention::AttentionScheduler scheduler; + torch::Tensor metadata = scheduler.schedule(input); + return metadata; +} + +void cpu_attn_reshape_and_cache( + const torch::Tensor& key, // [token_num, head_num, head_size] + const torch::Tensor& value, // [token_num, head_num, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& slot_mapping, const std::string& isa) { + TORCH_CHECK_EQ(key.dim(), 3); + TORCH_CHECK_EQ(value.dim(), 3); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + TORCH_CHECK_EQ(key.stride(2), 1); + TORCH_CHECK_EQ(value.stride(2), 1); + + const int64_t token_num = key.size(0); + const int64_t key_token_num_stride = key.stride(0); + const int64_t value_token_num_stride = value.stride(0); + const int64_t head_num = value.size(1); + const int64_t key_head_num_stride = key.stride(1); + const int64_t value_head_num_stride = value.stride(1); + const int64_t num_blocks = key_cache.size(0); + const int64_t num_blocks_stride = key_cache.stride(0); + const int64_t cache_head_num_stride = key_cache.stride(1); + const int64_t block_size = key_cache.size(2); + const int64_t block_size_stride = key_cache.stride(2); + const int64_t head_dim = key.size(-1); + + cpu_attention::ISA isa_tag = [&]() { + if (isa == "amx") { + return cpu_attention::ISA::AMX; + } else if (isa == "vec") { + return cpu_attention::ISA::VEC; + } else if (isa == "vec16") { + return cpu_attention::ISA::VEC16; + } else { + TORCH_CHECK(false, "Invalid ISA type: " + isa); + } + }(); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { + CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() { + attn_impl::reshape_and_cache( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), token_num, + key_token_num_stride, value_token_num_stride, head_num, + key_head_num_stride, value_head_num_stride, num_blocks, + num_blocks_stride, cache_head_num_stride, block_size, + block_size_stride); + }); + }); + }); +} + +void cpu_attention_with_kv_cache( + const torch::Tensor& query, // [num_tokens, num_heads, head_size] + const torch::Tensor& + key_cache, // [num_blocks, num_kv_heads, block_size, head_size] + const torch::Tensor& + value_cache, // [num_blocks, num_kv_heads, block_size, head_size] + torch::Tensor& output, // [num_tokens, num_heads, head_size] + const torch::Tensor& query_start_loc, // [num_tokens + 1] + const torch::Tensor& seq_lens, // [num_tokens] + const double scale, const bool causal, + const std::optional& alibi_slopes, // [num_heads] + const int64_t sliding_window_left, const int64_t sliding_window_right, + const torch::Tensor& block_table, // [num_tokens, max_block_num] + const double softcap, const torch::Tensor& scheduler_metadata, + const std::optional& s_aux // [num_heads] +) { + TORCH_CHECK_EQ(query.dim(), 3); + TORCH_CHECK_EQ(query.stride(2), 1); + TORCH_CHECK_EQ(key_cache.dim(), 4); + TORCH_CHECK_EQ(value_cache.dim(), 4); + + cpu_attention::AttentionInput input; + input.metadata = reinterpret_cast( + scheduler_metadata.data_ptr()); + input.num_tokens = query.size(0); + input.num_heads = query.size(1); + input.num_kv_heads = key_cache.size(1); + input.block_size = key_cache.size(2); + input.query = query.data_ptr(); + input.query_num_tokens_stride = query.stride(0); + 
input.query_num_heads_stride = query.stride(1); + input.cache_num_blocks_stride = key_cache.stride(0); + input.cache_num_kv_heads_stride = key_cache.stride(1); + input.blt_num_tokens_stride = block_table.stride(0); + input.key_cache = key_cache.data_ptr(); + input.value_cache = value_cache.data_ptr(); + input.output = output.data_ptr(); + input.query_start_loc = query_start_loc.data_ptr(); + input.seq_lens = seq_lens.data_ptr(); + input.block_table = block_table.data_ptr(); + input.alibi_slopes = + alibi_slopes.has_value() ? alibi_slopes->data_ptr() : nullptr; + // For now sink must be bf16 + input.s_aux = s_aux.has_value() ? s_aux->data_ptr() : nullptr; + input.scale = scale; + input.causal = causal; + input.sliding_window_left = sliding_window_left; + input.sliding_window_right = sliding_window_right; + if (input.causal) { + // to make boundary calculation easier + input.sliding_window_right = 0; + } + float softcap_fp32 = softcap; + input.softcap = softcap_fp32; + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "cpu_attention_with_kv_cache", [&]() { + CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] { + CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() { + TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0); + cpu_attention::AttentionMainLoop mainloop; + mainloop(&input); + }); + }); + }); +} diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp new file mode 100644 index 000000000000..8da458b99119 --- /dev/null +++ b/csrc/cpu/cpu_attn_amx.hpp @@ -0,0 +1,511 @@ +#ifndef CPU_ATTN_AMX_HPP +#define CPU_ATTN_AMX_HPP + +#include "cpu_attn_impl.hpp" + +namespace cpu_attention { +namespace { +// AMX specific +constexpr static int64_t AMX_TILE_ROW_BYTES = 64; +constexpr static int64_t AMX_TILE_ROW_NUM = 16; +constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM; + +typedef struct __tile_config { + uint8_t palette_id = 1; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +} __tilecfg; + +// 2-2-4 pattern, for 16 < m <= 32 +// TILE 0, 1: load A matrix, row num should be 16, m - 16 +// TILE 2, 3: load B matrix, row num should be 16 +// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m +// - 16 +template +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224"); + } +}; + +template <> +class TileGemm224 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM; + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return 
AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + // k_cache, v_cache are prepacked + const int32_t b_tile_stride = AMX_TILE_ROW_BYTES; + + // logits_buffer, output_buffer are not prepacked + float* __restrict__ c_tile_4 = c_tile; + float* __restrict__ c_tile_5 = + c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float); + float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc; + float* __restrict__ c_tile_7 = + c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float); + const int32_t c_tile_stride = ldc * sizeof(float); + + if (accum_c) { + _tile_loadd(4, c_tile_4, c_tile_stride); + _tile_loadd(5, c_tile_5, c_tile_stride); + _tile_loadd(6, c_tile_6, c_tile_stride); + _tile_loadd(7, c_tile_7, c_tile_stride); + } else { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_tile_stride); + _tile_dpbf16ps(4, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_tile_stride); + _tile_dpbf16ps(5, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + _tile_stored(4, c_tile_4, c_tile_stride); + _tile_stored(5, c_tile_5, c_tile_stride); + _tile_stored(6, c_tile_6, c_tile_stride); + _tile_stored(7, c_tile_7, c_tile_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + const int32_t m_0 = AMX_TILE_ROW_NUM; + const int32_t m_1 = m - AMX_TILE_ROW_NUM; + config.rows[0] = m_0; + config.rows[1] = m_1; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = m_0; + config.rows[5] = m_0; + config.rows[6] = m_1; + config.rows[7] = m_1; + _tile_loadconfig(&config); + } +}; + +// 1-2-2 pattern, for 0 < m <= 16 +// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be +// m, m +// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row +// num should be 16 +// TILE 6, 7, (6, 7): store results C matrix, row num should be +// m +template +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, + void* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + TORCH_CHECK(false, "Unsupported kv 
cache type for TileGemm122"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122"); + } +}; + +template <> +class TileGemm122 { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + c10::BFloat16* __restrict__ a_tile, + c10::BFloat16* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + c10::BFloat16* __restrict__ a_tile_0 = a_tile; + c10::BFloat16* __restrict__ a_tile_1 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + const int64_t a_tile_stride = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // q_buffer is prepacked + return AMX_TILE_ROW_BYTES; + } else if constexpr (phase == AttentionGemmPhase::PV) { + // logits_buffer is row-major + return lda * sizeof(c10::BFloat16); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + + c10::BFloat16* __restrict__ b_tile_2 = b_tile; + c10::BFloat16* __restrict__ b_tile_3 = [&]() { + if constexpr (phase == AttentionGemmPhase::QK) { + // k_cache is prepacked + return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // v_cache is prepacked + return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4); + } else { + TORCH_CHECK(false, "Unreachable"); + } + }(); + c10::BFloat16* __restrict__ b_tile_4 = + b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + c10::BFloat16* __restrict__ b_tile_5 = + b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + int64_t b_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_6 = c_tile; + float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float); + int64_t c_stride = ldc * sizeof(float); + + const int32_t k_times = + dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + const int32_t k_group_times = k_times / 2; + const bool has_tail = (k_times % 2 == 1); + + if (accum_c) { + _tile_loadd(6, c_tile_6, c_stride); + _tile_loadd(7, c_tile_7, c_stride); + } else { + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_group_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_stream_loadd(4, b_tile_4, b_stride); + _tile_dpbf16ps(6, 1, 4); + _tile_stream_loadd(5, b_tile_5, b_stride); + _tile_dpbf16ps(7, 1, 5); + + // update ptrs + if constexpr (phase == AttentionGemmPhase::QK) { + // Q buffer is prepacked + a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } else if constexpr (phase == AttentionGemmPhase::PV) { + // P buffer is not prepacked + a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + } + b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_5 += 2 * AMX_TILE_BYTES / 
sizeof(c10::BFloat16); + } + + if (has_tail) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + } + + _tile_stored(6, c_tile_6, c_stride); + _tile_stored(7, c_tile_7, c_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + config.rows[0] = m; + config.rows[1] = m; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = AMX_TILE_ROW_NUM; + config.rows[5] = AMX_TILE_ROW_NUM; + config.rows[6] = m; + config.rows[7] = m; + _tile_loadconfig(&config); + } +}; +} // namespace + +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = scalar_t; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = scalar_t; + + constexpr static int64_t BlockSizeAlignment = + AMX_TILE_ROW_BYTES / + sizeof(kv_cache_t); // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + 2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = 32; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::AMX; + constexpr static bool scale_on_logits = true; + + public: + AttentionImpl() : current_q_head_num_(0) { + // Use all columns in AMX tiles + vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; }); + } + + ~AttentionImpl() { _tile_release(); } + + template