From 98e69cd2a4dfa33eac4d35c3308d55e34e0205f0 Mon Sep 17 00:00:00 2001 From: 0hujun <96733800+0hujun@users.noreply.github.com> Date: Thu, 28 May 2026 15:17:28 +0800 Subject: [PATCH 1/3] fix: Npu Group MatMul op patchs only in EP --- src/twinkle/kernel/monkey_patch_npu.py | 76 +++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index acb2a6aa..48f5b6af 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -44,6 +44,18 @@ def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1): return unsqueeze_dim +def _is_ep_enabled(model=None) -> bool: + r"""Check whether Expert Parallelism (EP) is enabled. + + EP is detected via ``device_mesh.ep_size > 1``. + When EP is active, each rank holds only a subset of expert weights, + making ``npu_grouped_matmul`` efficient (small contiguous weights). + """ + if model is None or model.device_mesh is None: + return False + return (getattr(model.device_mesh, 'ep_size', None) or 0) > 1 + + # ============================================================================= # Section 1: MoE Grouped MatMul (GMM) # ============================================================================= @@ -99,8 +111,45 @@ def _grouped_mm_npu(input: torch.tensor, weight_ekn: torch.tensor, offs: torch.t return GmmFunction.apply(input, counts, weight_ekn) -def apply_hf_moe_grouped_mm_patch() -> None: - r"""Patch HuggingFace MoE integration to use NPU grouped matmul.""" +def _apply_hf_moe_grouped_mm_patch(model=None) -> None: + r"""Patch HuggingFace MoE integration to use NPU grouped matmul. + + When Expert Parallelism (EP) is **not** enabled, each rank holds **all** + expert weights. ``weight.transpose(-2, -1)`` then produces a large + non-contiguous view that ``npu_grouped_matmul`` forces to ``.contiguous()`` + (~12.88 GB per MoE layer), creating a bandwidth bottleneck that makes the + NPU patch **slower** than the native per-expert fallback (~8x overhead). + + Detection logic: + - ``TWINKLE_NPU_GMM_PATCH`` not set → **skip** the patch by default. + - ``TWINKLE_NPU_GMM_PATCH=1`` → EP-aware: apply only if EP is enabled + (each rank has few experts, weights are small and contiguous); + skip if EP is **not** enabled (avoid ~8x overhead). + - ``TWINKLE_NPU_GMM_PATCH=0`` → **disable** the patch regardless. + """ + moe_enabled = _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=True) + + if not moe_enabled: + has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') + logger.info( + '[PATCH] TWINKLE_NPU_GMM_PATCH not set: MoE GMM patch skipped by default. ' + 'Set TWINKLE_NPU_GMM_PATCH=1 to enable (EP-aware). ' + 'Native grouped_mm available: %s.', + has_native_gmm, + ) + return + + if not _is_ep_enabled(model): + has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm') + logger.info( + '[PATCH] TWINKLE_NPU_GMM_PATCH=1 but EP not enabled (all experts on each rank) — ' + 'skipping _grouped_mm_npu patch to avoid ~8x overhead from ' + 'contiguous copies on transposed weights. ' + 'Native grouped_mm available: %s.', + has_native_gmm, + ) + return + import transformers.integrations.moe as hf_moe hf_moe._grouped_mm = _grouped_mm_npu logger.info('[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu') @@ -865,24 +914,34 @@ def apply_npu_patch(model=None) -> None: - SDPA Attention compatibility fixes - Flash Linear Attention (FLA) for Qwen3.5 + When ``model`` is **not** provided, the GMM patch is **skipped** by default + (EP cannot be detected without a model instance). + + When ``model`` is provided, the GMM patch is evaluated with EP detection: + - EP enabled → apply GMM patch (efficient on small sharded weights). + - EP not enabled → skip GMM patch (avoid ~8x contiguous-copy overhead). + Environment variables: - ``TWINKLE_NPU_PATCH``: overall switch (``1``/``0``) - ``TWINKLE_NPU_FUSED_OPS``: fused ops switch (``1``/``0``) - - ``TWINKLE_NPU_MOE_PATCH``: MoE GMM switch (``1``/``0``) + - ``TWINKLE_NPU_GMM_PATCH``: MoE GMM switch (``1``/``0``/unset). + When unset: skip the patch by default. + When ``1``: EP-aware — patch is applied **only if EP is enabled**; + without EP the native grouped_mm or per-expert fallback is used + (avoiding ~8x overhead from contiguous copies). + When ``0``: disable the patch regardless. - ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``) - ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``) Args: - model: Optional model instance. Required for FLA to traverse and - replace per-instance ``chunk_gated_delta_rule`` bindings. + model: Optional model instance. If not provided, GMM patch is skipped. + If provided, GMM patch is evaluated with EP detection on the model. """ global _NPU_PATCH_APPLIED if not _is_env_enabled('TWINKLE_NPU_PATCH', default=True): return - moe_enabled = _is_env_enabled('TWINKLE_NPU_MOE_PATCH', default=True) - if _NPU_PATCH_APPLIED: logger.debug('[NPU] Patches already applied, skipping.') return @@ -893,8 +952,7 @@ def apply_npu_patch(model=None) -> None: logger.warning('torch_npu not available. Skipping NPU patches.') return - if moe_enabled: - apply_hf_moe_grouped_mm_patch() + _apply_hf_moe_grouped_mm_patch(model) _apply_all_fused_ops(model) From 1992ca0ad914e181c2e3da92221835f6b7183577 Mon Sep 17 00:00:00 2001 From: 0hujun <96733800+0hujun@users.noreply.github.com> Date: Thu, 28 May 2026 15:24:01 +0800 Subject: [PATCH 2/3] Update src/twinkle/kernel/monkey_patch_npu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/twinkle/kernel/monkey_patch_npu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index 48f5b6af..2b9a31d4 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -51,9 +51,10 @@ def _is_ep_enabled(model=None) -> bool: When EP is active, each rank holds only a subset of expert weights, making ``npu_grouped_matmul`` efficient (small contiguous weights). """ - if model is None or model.device_mesh is None: + device_mesh = getattr(model, 'device_mesh', None) + if device_mesh is None: return False - return (getattr(model.device_mesh, 'ep_size', None) or 0) > 1 + return (getattr(device_mesh, 'ep_size', None) or 0) > 1 # ============================================================================= From 598c5abdd4714db92a5ae11e13e3c3c0d918d669 Mon Sep 17 00:00:00 2001 From: 0hujun <96733800+0hujun@users.noreply.github.com> Date: Thu, 28 May 2026 15:24:45 +0800 Subject: [PATCH 3/3] Update src/twinkle/kernel/monkey_patch_npu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/twinkle/kernel/monkey_patch_npu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index 2b9a31d4..07213968 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -128,7 +128,7 @@ def _apply_hf_moe_grouped_mm_patch(model=None) -> None: skip if EP is **not** enabled (avoid ~8x overhead). - ``TWINKLE_NPU_GMM_PATCH=0`` → **disable** the patch regardless. """ - moe_enabled = _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=True) + moe_enabled = _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=False) if not moe_enabled: has_native_gmm = hasattr(torch.nn.functional, 'grouped_mm')