diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 8cc41f5df8..9fca4754da 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -13,6 +13,7 @@ from lmdeploy.pytorch import envs as _envs from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.pytorch.distributed import get_dist_manager +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from lmdeploy.utils import get_logger from ..moe import DlinferMoECommType, DlinferMoeMetadata @@ -156,8 +157,23 @@ def update_step_context(cls, step_context): block_num, block_size, *_ = step_context.kv_caches[0][0].shape is_prefill_no_cache = False + num_spec_tokens = get_step_ctx_manager().build_ctx.num_spec_tokens + if not step_context.is_decoding: is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) + is_multi_token_decoding = False + is_decoding = False + else: + is_multi_token_decoding = step_context.q_seqlens.max().item() > 1 + # is_decoding: True only for regular single-token decode (original semantics) + is_decoding = not is_multi_token_decoding + + # MoE EP dispatch/combine and graph capture are collective ops shared by all + # DP ranks, so they must agree on decode-vs-prefill. Use the DP-global state + # (if any rank is prefill, all ranks are prefill) for those paths; the local + # is_decoding / is_multi_token_decoding above stay rank-local for attention. + global_is_decoding = step_context.global_is_decoding() + if step_context.block_offsets.dtype != torch.int32: step_context.block_offsets = step_context.block_offsets.to(torch.int32) if step_context.kv_seqlens.dtype != torch.int32: @@ -180,8 +196,6 @@ def get_cpu_seqlens(is_decoding, is_prefill_no_cache): q_seqlens_cpu: query sequence lengths (per sequence). kv_seqlens_cpu: kv sequence lengths (per sequence), used for list/max seqlens calculation. - kv_seqlens_expanded: kv sequence lengths expanded per token via - repeat_interleave, used for attention metadata. """ if is_decoding: q_seqlens_cpu = None @@ -219,7 +233,8 @@ def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None): return torch.arange(1, batch_size + 1, dtype=torch.int32) elif is_prefill_no_cache: return q_seqlens_cpu - return q_seqlens_cpu.cumsum(dim=0) + # for paged_prefill, eg. MTP, prefix caching + return q_seqlens_cpu.cumsum(dim=0).to(torch.int32) def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len): @@ -277,12 +292,29 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group): if ep_size <= 1: return 0, 0, 0 # get padded_tokens_current_rank - is_graph = cls.enable_graph and step_context.is_decoding + is_graph = cls.enable_graph and global_is_decoding and (is_decoding or is_multi_token_decoding) if is_graph: from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size - actual_tokens_current_rank = step_context.q_seqlens.shape[0] - padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank), - cls.max_batches) + # The cudagraph is keyed/captured on the GLOBAL padded batch + # (max over all DP ranks), so every DP rank executes the MoE with + # the same global token count. padded_tokens_current_rank must + # therefore mirror that global captured size; deriving it from this + # rank's local batch makes DP ranks disagree on the MC2 + # dispatch/combine token count and corrupts the collective + # (MoeDistributeCombineV2 AICORE out-of-bounds). dp_meta.dp_batches + # holds the per-rank sequence counts; its max is the global batch + # the graph capture uses. + dp_meta = step_context.dp_meta + if dp_meta is not None and dp_meta.dp_batches: + global_batch = max(dp_meta.dp_batches) + else: + global_batch = step_context.q_seqlens.shape[0] + query_len = (num_spec_tokens + 1) if is_multi_token_decoding else 1 + # actual tokens: this rank's real (non-padded) token count, used to + # build x_active_mask so MC2 ignores the graph padding region. + actual_tokens_current_rank = step_context.q_seqlens.sum().item() + padded_tokens_current_rank = min(get_ascend_compatible_size(global_batch), + cls.max_batches) * query_len else: actual_tokens_current_rank = step_context.q_seqlens.sum().item() padded_tokens_current_rank = actual_tokens_current_rank @@ -303,7 +335,7 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group): @lru_cache def init_mc2_token_capacity(tp_size): - max_num_tokens = min(cls.max_batches, 512) + max_num_tokens = min(cls.max_batches * (num_spec_tokens + 1), 512) num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size return num_tokens_per_tp_rank * tp_size @@ -311,7 +343,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size): if ep_size <= 1: return DlinferMoECommType.ALLGATHER mc2_token_capacity = init_mc2_token_capacity(tp_size) - is_graph = cls.enable_graph and step_context.is_decoding + is_graph = cls.enable_graph and global_is_decoding and (is_decoding or is_multi_token_decoding) if is_graph: max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size if SocVersion.is_A2(): @@ -320,7 +352,7 @@ def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size): else: return DlinferMoECommType.ALLGATHER elif SocVersion.is_A3(): - if max_tokens_across_dp <= mc2_token_capacity: + if max_tokens_across_dp <= mc2_token_capacity and global_is_decoding: return DlinferMoECommType.MC2 else: return DlinferMoECommType.ALLTOALL @@ -337,7 +369,7 @@ def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tok dtype=torch.bool, device=torch.npu.current_device()) elif moe_comm_type == DlinferMoECommType.ALLTOALL: - pad_size = tp_size - padded_tokens_current_rank + pad_size = (-padded_tokens_current_rank) % tp_size elif moe_comm_type == DlinferMoECommType.ALLGATHER: pad_size = max_tokens_across_dp - padded_tokens_current_rank else: @@ -353,16 +385,16 @@ def get_moe_group_name(group): group_name = backend.get_hccl_comm_name(local_rank) return group_name - q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache) - q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu, + q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(is_decoding, is_prefill_no_cache) + q_seqlens_list, kv_seqlens_list = get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu, kv_seqlens_cpu) - max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list, + max_q_seq_len, max_kv_seq_len = get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list) - kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, + kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list, max_q_seq_len, max_kv_seq_len) - q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu) + q_seqlens_cpu = update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu) if not cls.enable_graph and step_context.kv_quant_policy == 8: record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') @@ -379,18 +411,57 @@ def get_moe_group_name(group): cu_seqlens = None has_initial_state = None - + spec_conv_offsets = None + spec_state_offsets = None + cache_seqlens = None is_gated_delta = step_context.model_config.is_gated_delta if is_gated_delta: - q_start_loc = step_context.q_start_loc.to(dtype=step_context.q_seqlens.dtype, - device=step_context.q_seqlens.device) - cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() - if not step_context.is_decoding: - has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) + q_seqlens = step_context.q_seqlens + kv_seqlens = step_context.kv_seqlens + + q_start_loc = step_context.q_start_loc.to(dtype=q_seqlens.dtype, + device=q_seqlens.device) + cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int() + cache_seqlens = (kv_seqlens - q_seqlens).contiguous() + + + states_shapes = step_context.model_config.states_shapes + if not is_decoding and not is_multi_token_decoding and len(states_shapes) > 0: + has_initial_state = ~(q_seqlens == kv_seqlens) + # # Conv ring buffer: conv_state_len = conv_kernel_size + num_spec_tokens. + conv_state_len = states_shapes[0][0][0] + conv_kernel_size = conv_state_len - num_spec_tokens + + if num_spec_tokens > 0: + state_slots = 1 + num_spec_tokens + spec_state_offsets = ( + torch.remainder(cache_seqlens, state_slots), + torch.remainder(kv_seqlens, state_slots), + ) + + range_idx = torch.arange( + -conv_kernel_size, + 0, + device=cache_seqlens.device, + dtype=torch.int32, + ) + # Read the (conv_kernel_size - 1) tokens preceding the current write + # window from the circular buffer. + read_conv_offsets = torch.remainder( + cache_seqlens[:, None] + range_idx[1:][None], + conv_state_len, + ).to(torch.int64) + # Write the last conv_kernel_size tokens of this prefill batch into + # circular-buffer slots so the next decode read aligns naturally. + write_conv_offsets = torch.remainder( + kv_seqlens[:, None] + range_idx[None], + conv_state_len, + ).to(torch.int64) + spec_conv_offsets = (read_conv_offsets, write_conv_offsets) attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( - step_context.is_decoding, + is_decoding, step_context.block_offsets, # cu_seqlens is only used in GDN and is passed down via q_start_loc. # Otherwise, q_start_loc is None. @@ -406,6 +477,10 @@ def get_moe_group_name(group): quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, has_initial_state=has_initial_state, + is_multi_token_decoding=is_multi_token_decoding, + spec_conv_offsets=spec_conv_offsets, + spec_state_offsets=spec_state_offsets, + cache_seqlens=cache_seqlens, ) step_context.attn_metadata = attn_metadata @@ -462,6 +537,14 @@ def init(): logger.warning(f'Error during Ascend initialization: {str(e)}. ' 'Please check your Ascend environment configuration.') + try: + import dlinfer.framework.lmdeploy_ext.device # noqa: F401 — triggers vendor_device_init() + except ImportError: + logger.warning('dlinfer framework extensions not found. ' + 'Ascend-specific model patches will not be applied.') + except Exception as e: + logger.warning(f'Error during dlinfer extension initialization: {str(e)}. ' + 'Ascend-specific model patches may not be applied.') try: from dlinfer.vendor.ascend.triton_ops.triton_utils import init_device_properties_triton init_device_properties_triton() diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index a8eea27545..45756fb7c4 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -19,6 +19,10 @@ class DlinferAttentionMetadata(AttentionMetadata): quant_meta: dict = None cu_seq_lens_kv: Tensor | None = None has_initial_state: Tensor | None = None + is_multi_token_decoding: bool = False + spec_conv_offsets: Sequence[Tensor] | None = None + spec_state_offsets: Sequence[Tensor] | None = None + cache_seqlens: Tensor | None = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 344e8b53b7..3d1eb0a593 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -627,6 +627,7 @@ def from_config( block_size=target_cache_cfg.block_size, model_format=model_format, hf_overrides=hf_overrides, + device_type=target_cache_cfg.device_type, ) cache_config = None # include medusa