diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index acb2a6aa..07213968 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -44,6 +44,19 @@ def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): return unsqueeze_dim +def _is_ep_enabled(model=None) -> bool: + r"""Check whether Expert Parallelism (EP) is enabled. + + EP is detected via ``device_mesh.ep_size > 1``. + When EP is active, each rank holds only a subset of expert weights, + making ``npu_grouped_matmul`` efficient (small contiguous weights). + """ + device_mesh = getattr(model, 'device_mesh', None) + if device_mesh is None: + return False + return (getattr(device_mesh, 'ep_size', None) or 0) > 1 + + # ============================================================================= # Section 1: MoE Grouped MatMul (GMM) # ============================================================================= @@ -99,8 +112,45 @@ def _grouped_mm_npu(input: torch.tensor, weight_ekn: torch.tensor, offs: torch.t return GmmFunction.apply(input, counts, weight_ekn) -def apply_hf_moe_grouped_mm_patch() -> None: - r"""Patch HuggingFace MoE integration to use NPU grouped matmul.""" +def _apply_hf_moe_grouped_mm_patch(model=None) -> None: + r"""Patch HuggingFace MoE integration to use NPU grouped matmul. + + When Expert Parallelism (EP) is **not** enabled, each rank holds **all** + expert weights. ``weight.transpose(-2, -1)`` then produces a large + non-contiguous view that ``npu_grouped_matmul`` forces to ``.contiguous()`` + (~12.88 GB per MoE layer), creating a bandwidth bottleneck that makes the + NPU patch **slower** than the native per-expert fallback (~8x overhead). + + Detection logic: + - ``TWINKLE_NPU_GMM_PATCH`` not set → **skip** the patch by default. + - ``TWINKLE_NPU_GMM_PATCH=1`` → EP-aware: apply only if EP is enabled + (each rank has few experts, weights are small and contiguous); + skip if EP is **not** enabled (avoid ~8x overhead). + - ``TWINKLE_NPU_GMM_PATCH=0`` → **disable** the patch regardless. + """ + moe_enabled = _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=False) + + if not moe_enabled: + has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') + logger.info( + '[PATCH] TWINKLE_NPU_GMM_PATCH not set: MoE GMM patch skipped by default. ' + 'Set TWINKLE_NPU_GMM_PATCH=1 to enable (EP-aware). ' + 'Native grouped_mm available: %s.', + has_native_gmm, + ) + return + + if not _is_ep_enabled(model): + has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') + logger.info( + '[PATCH] TWINKLE_NPU_GMM_PATCH=1 but EP not enabled (all experts on each rank) — ' + 'skipping _grouped_mm_npu patch to avoid ~8x overhead from ' + 'contiguous copies on transposed weights. ' + 'Native grouped_mm available: %s.', + has_native_gmm, + ) + return + import transformers.integrations.moe as hf_moe hf_moe._grouped_mm = _grouped_mm_npu logger.info('[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu') @@ -865,24 +915,34 @@ def apply_npu_patch(model=None) -> None: - SDPA Attention compatibility fixes - Flash Linear Attention (FLA) for Qwen3.5 + When ``model`` is **not** provided, the GMM patch is **skipped** by default + (EP cannot be detected without a model instance). + + When ``model`` is provided, the GMM patch is evaluated with EP detection: + - EP enabled → apply GMM patch (efficient on small sharded weights). + - EP not enabled → skip GMM patch (avoid ~8x contiguous-copy overhead). + Environment variables: - ``TWINKLE_NPU_PATCH``: overall switch (``1``/``0``) - ``TWINKLE_NPU_FUSED_OPS``: fused ops switch (``1``/``0``) - - ``TWINKLE_NPU_MOE_PATCH``: MoE GMM switch (``1``/``0``) + - ``TWINKLE_NPU_GMM_PATCH``: MoE GMM switch (``1``/``0``/unset). + When unset: skip the patch by default. + When ``1``: EP-aware — patch is applied **only if EP is enabled**; + without EP the native grouped_mm or per-expert fallback is used + (avoiding ~8x overhead from contiguous copies). + When ``0``: disable the patch regardless. - ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``) - ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``) Args: - model: Optional model instance. Required for FLA to traverse and - replace per-instance ``chunk_gated_delta_rule`` bindings. + model: Optional model instance. If not provided, GMM patch is skipped. + If provided, GMM patch is evaluated with EP detection on the model. """ global _NPU_PATCH_APPLIED if not _is_env_enabled('TWINKLE_NPU_PATCH', default=True): return - moe_enabled = _is_env_enabled('TWINKLE_NPU_MOE_PATCH', default=True) - if _NPU_PATCH_APPLIED: logger.debug('[NPU] Patches already applied, skipping.') return @@ -893,8 +953,7 @@ def apply_npu_patch(model=None) -> None: logger.warning('torch_npu not available. Skipping NPU patches.') return - if moe_enabled: - apply_hf_moe_grouped_mm_patch() + _apply_hf_moe_grouped_mm_patch(model) _apply_all_fused_ops(model)