Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions src/twinkle/kernel/monkey_patch_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ def _resolve_unsqueeze_dim(position_ids=None, unsqueeze_dim=1):
return unsqueeze_dim


def _is_ep_enabled(model=None) -> bool:
r"""Check whether Expert Parallelism (EP) is enabled.

EP is detected via ``device_mesh.ep_size > 1``.
When EP is active, each rank holds only a subset of expert weights,
making ``npu_grouped_matmul`` efficient (small contiguous weights).
"""
device_mesh = getattr(model, 'device_mesh', None)
if device_mesh is None:
return False
return (getattr(device_mesh, 'ep_size', None) or 0) > 1


# =============================================================================
# Section 1: MoE Grouped MatMul (GMM)
# =============================================================================
Expand Down Expand Up @@ -99,8 +112,45 @@ def _grouped_mm_npu(input: torch.tensor, weight_ekn: torch.tensor, offs: torch.t
return GmmFunction.apply(input, counts, weight_ekn)


def apply_hf_moe_grouped_mm_patch() -> None:
r"""Patch HuggingFace MoE integration to use NPU grouped matmul."""
def _apply_hf_moe_grouped_mm_patch(model=None) -> None:
    r"""Patch HuggingFace MoE integration to use NPU grouped matmul.

    When applied, ``transformers.integrations.moe._grouped_mm`` is rebound to
    ``_grouped_mm_npu``. The patch is opt-in and EP-aware.

    When Expert Parallelism (EP) is **not** enabled, each rank holds **all**
    expert weights. ``weight.transpose(-2, -1)`` then produces a large
    non-contiguous view that ``npu_grouped_matmul`` forces to ``.contiguous()``
    (~12.88 GB per MoE layer), creating a bandwidth bottleneck that makes the
    NPU patch **slower** than the native per-expert fallback (~8x overhead).

    Detection logic:
    - ``TWINKLE_NPU_GMM_PATCH`` not set → **skip** the patch by default.
    - ``TWINKLE_NPU_GMM_PATCH=1`` → EP-aware: apply only if EP is enabled
      (each rank has few experts, weights are small and contiguous);
      skip if EP is **not** enabled (avoid ~8x overhead).
    - ``TWINKLE_NPU_GMM_PATCH=0`` → **disable** the patch regardless.

    Args:
        model: Optional model instance used for EP detection via
            ``_is_ep_enabled``. With ``None`` no EP can be detected, so the
            patch is skipped even when the environment switch is on.
    """
    # Pick the reason for skipping (if any). Both skip paths report whether a
    # native ``grouped_mm`` is available so users can tell which fallback
    # will be used.
    if not _is_env_enabled('TWINKLE_NPU_GMM_PATCH', default=False):
        # The switch may be unset *or* explicitly turned off (e.g. ``=0``);
        # the message covers both instead of claiming it is "not set".
        skip_message = (
            '[PATCH] TWINKLE_NPU_GMM_PATCH is unset or disabled: MoE GMM patch skipped. '
            'Set TWINKLE_NPU_GMM_PATCH=1 to enable (EP-aware). '
            'Native grouped_mm available: %s.'
        )
    elif not _is_ep_enabled(model):
        skip_message = (
            '[PATCH] TWINKLE_NPU_GMM_PATCH=1 but EP not enabled (all experts on each rank) — '
            'skipping _grouped_mm_npu patch to avoid ~8x overhead from '
            'contiguous copies on transposed weights. '
            'Native grouped_mm available: %s.'
        )
    else:
        skip_message = None

    if skip_message is not None:
        logger.info(skip_message, hasattr(torch.nn.functional, 'grouped_mm'))
        return

    # Imported lazily so that importing this module does not require the
    # transformers MoE integration unless the patch is actually applied.
    import transformers.integrations.moe as hf_moe
    hf_moe._grouped_mm = _grouped_mm_npu
    logger.info('[PATCH] transformers.integrations.moe._grouped_mm -> _grouped_mm_npu')
Expand Down Expand Up @@ -865,24 +915,34 @@ def apply_npu_patch(model=None) -> None:
- SDPA Attention compatibility fixes
- Flash Linear Attention (FLA) for Qwen3.5

When ``model`` is **not** provided, the GMM patch is **skipped** by default
(EP cannot be detected without a model instance).

When ``model`` is provided, the GMM patch is evaluated with EP detection:
- EP enabled → apply GMM patch (efficient on small sharded weights).
- EP not enabled → skip GMM patch (avoid ~8x contiguous-copy overhead).

Environment variables:
- ``TWINKLE_NPU_PATCH``: overall switch (``1``/``0``)
- ``TWINKLE_NPU_FUSED_OPS``: fused ops switch (``1``/``0``)
- ``TWINKLE_NPU_MOE_PATCH``: MoE GMM switch (``1``/``0``)
- ``TWINKLE_NPU_GMM_PATCH``: MoE GMM switch (``1``/``0``/unset).
When unset: skip the patch by default.
When ``1``: EP-aware — patch is applied **only if EP is enabled**;
without EP the native grouped_mm or per-expert fallback is used
(avoiding ~8x overhead from contiguous copies).
When ``0``: disable the patch regardless.
- ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``)
- ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``)

Args:
model: Optional model instance. Required for FLA to traverse and
replace per-instance ``chunk_gated_delta_rule`` bindings.
model: Optional model instance. If not provided, GMM patch is skipped.
If provided, GMM patch is evaluated with EP detection on the model.
"""
global _NPU_PATCH_APPLIED

if not _is_env_enabled('TWINKLE_NPU_PATCH', default=True):
return

moe_enabled = _is_env_enabled('TWINKLE_NPU_MOE_PATCH', default=True)

if _NPU_PATCH_APPLIED:
logger.debug('[NPU] Patches already applied, skipping.')
return
Expand All @@ -893,8 +953,7 @@ def apply_npu_patch(model=None) -> None:
logger.warning('torch_npu not available. Skipping NPU patches.')
return

if moe_enabled:
apply_hf_moe_grouped_mm_patch()
_apply_hf_moe_grouped_mm_patch(model)

_apply_all_fused_ops(model)

Expand Down
Loading