Skip to content
Open
136 changes: 136 additions & 0 deletions tests/jax/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,139 @@ def loss_fused_fn(probs_):
assert jnp.allclose(
grad_ref, grad_fused, atol=1e-5, rtol=1e-5
), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}"


# =============================================================================
# Test: routing_map BITMAP_U8 vs BYTEMAP parity (fwd + bwd)
# =============================================================================


def _bytemap_to_bitmap_u8(bytemap):
    """Reference packer: pack a byte-per-entry mask into an LSB-first uint8 bitmap.

    The input is converted to a NumPy ``uint8`` array and packed along its last
    axis with ``np.packbits(..., bitorder="little")``: entry ``e`` of a row ends
    up in bit ``e % 8`` of output byte ``e // 8``. When the last axis is not a
    multiple of 8, ``np.packbits`` zero-pads the final byte, so a ``[T, E]``
    input yields a ``[T, ceil(E / 8)]`` result.
    """
    import numpy as np

    return np.packbits(
        np.asarray(bytemap).astype(np.uint8),
        axis=-1,
        bitorder="little",
    )


@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper(
    "num_tokens,num_experts,topk",
    TOPK_CASES,
)
@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS)
@pytest.mark.triton
def test_topk_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function):
    """Check BITMAP_U8 / BYTEMAP parity of fused_topk_with_score_function.

    Forward: both formats must return identical probs, and the BITMAP_U8
    routing map must be a uint8 array of shape ``(num_tokens, ceil(num_experts / 8))``
    equal to the BYTEMAP routing map packed LSB-first by the reference packer.

    Backward: the gradient of ``probs.sum()`` w.r.t. the logits must match
    across the two formats with zero tolerance.
    """
    from transformer_engine.jax.router import RoutingMapFormat

    logits = make_logits(num_tokens, num_experts, score_function, dtype)
    expert_bias = None

    def jitted_forward(map_format):
        # Everything except the logits is bound up-front so jit only traces the array input.
        return jax.jit(
            partial(
                fused_topk_with_score_function,
                topk=topk,
                score_function=score_function,
                expert_bias=expert_bias,
                routing_map_format=map_format,
            )
        )

    fwd_byte = jitted_forward(RoutingMapFormat.BYTEMAP)
    fwd_bit = jitted_forward(RoutingMapFormat.BITMAP_U8)
    probs_byte, routing_map_byte = fwd_byte(logits)
    probs_bit, routing_map_bit = fwd_bit(logits)

    # --- forward: probabilities ---
    assert probs_byte.dtype == probs_bit.dtype
    assert jnp.array_equal(probs_byte, probs_bit), "Probs must be identical across formats"

    # --- forward: routing map layout ---
    packed_expected = _bytemap_to_bitmap_u8(routing_map_byte)
    bitmap_shape = (num_tokens, (num_experts + 7) // 8)
    assert (
        routing_map_bit.shape == bitmap_shape
    ), f"Bitmap shape {routing_map_bit.shape} != ({num_tokens}, {(num_experts + 7) // 8})"
    assert routing_map_bit.dtype == jnp.uint8
    assert jnp.array_equal(
        routing_map_bit, packed_expected
    ), "Bitmap routing_map disagrees with np.packbits(bytemap, bitorder='little')"

    # --- backward: grad of probs.sum() must be identical for both formats ---
    def summed_probs_loss(map_format):
        def loss(logits_):
            p, _ = fused_topk_with_score_function(
                logits_,
                topk,
                score_function=score_function,
                expert_bias=expert_bias,
                routing_map_format=map_format,
            )
            return p.sum()

        return loss

    grad_byte = jax.jit(jax.grad(summed_probs_loss(RoutingMapFormat.BYTEMAP)))(logits)
    grad_bit = jax.jit(jax.grad(summed_probs_loss(RoutingMapFormat.BITMAP_U8)))(logits)
    assert jnp.allclose(grad_byte, grad_bit, atol=0.0, rtol=0.0), (
        "Backward grad must be bit-identical across routing_map_format; "
        f"max diff = {jnp.abs(grad_byte - grad_bit).max()}"
    )


@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper(
    "num_tokens,num_experts,topk",
    SCORE_AUX_LOSS_CASES,
)
@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS)
@pytest.mark.triton
def test_score_for_aux_loss_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function):
    """Check BITMAP_U8 / BYTEMAP parity on the ``compute_aux_scores=True`` path.

    The scores returned for both formats must be exactly equal, and the
    BITMAP_U8 routing map must be a uint8 array of shape
    ``(num_tokens, ceil(num_experts / 8))`` equal to the LSB-first packing of
    the BYTEMAP routing map.
    """
    from transformer_engine.jax.router import RoutingMapFormat

    logits = make_logits(num_tokens, num_experts, score_function, dtype)

    def jitted_aux_forward(map_format):
        # Bind every non-array argument so the jitted callable only takes the logits.
        return jax.jit(
            partial(
                fused_topk_with_score_function,
                topk=topk,
                score_function=score_function,
                compute_aux_scores=True,
                routing_map_format=map_format,
            )
        )

    fwd_byte = jitted_aux_forward(RoutingMapFormat.BYTEMAP)
    fwd_bit = jitted_aux_forward(RoutingMapFormat.BITMAP_U8)
    scores_byte, routing_map_byte = fwd_byte(logits)
    scores_bit, routing_map_bit = fwd_bit(logits)

    assert jnp.array_equal(scores_byte, scores_bit), "Scores must be identical across formats"

    packed_expected = _bytemap_to_bitmap_u8(routing_map_byte)
    bitmap_shape = (num_tokens, (num_experts + 7) // 8)
    assert routing_map_bit.shape == bitmap_shape
    assert routing_map_bit.dtype == jnp.uint8
    assert jnp.array_equal(
        routing_map_bit, packed_expected
    ), "Bitmap routing_map (aux-loss path) disagrees with packed bytemap"
140 changes: 140 additions & 0 deletions tests/pytorch/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from typing import Optional
from transformer_engine.pytorch.router import (
RoutingMapFormat,
fused_topk_with_score_function,
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
Expand Down Expand Up @@ -458,6 +459,145 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
torch.testing.assert_close(probs.grad, probs_clone.grad, atol=atol, rtol=rtol)


# =============================================================================
# Test: routing_map BITMAP_U8 vs BYTEMAP parity (fwd + bwd)
# Mirrors tests/jax/test_fused_router.py::test_topk_bitmap_vs_bytemap.
# =============================================================================


def _bytemap_to_bitmap_u8(bytemap: torch.Tensor) -> torch.Tensor:
    """Reference packer: pack a ``[T, E]`` mask into an LSB-first ``uint8`` bitmap.

    The mask is cast to ``uint8``, copied to host memory and packed along its
    last axis with ``numpy.packbits(..., bitorder="little")`` (the same packing
    the JAX-side parity test uses), giving ``[T, ceil(E / 8)]`` bytes. The
    result is moved back to the device of the input tensor.
    """
    import numpy as np

    host_bytes = bytemap.to(torch.uint8).cpu().numpy()
    packed = np.packbits(host_bytes, axis=-1, bitorder="little")
    return torch.from_numpy(packed).to(bytemap.device)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(128, 32, 4), (256, 128, 8), (256, 130, 8), (128, 1024, 16)],
)
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_topk_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function):
    """Check BITMAP_U8 / BYTEMAP parity of fused_topk_with_score_function.

    Forward: probs must match with zero tolerance; the BYTEMAP routing map must
    be ``torch.bool`` and the BITMAP_U8 routing map must be ``torch.uint8`` of
    shape ``(num_tokens, ceil(num_experts / 8))`` and equal to the LSB-first
    packing of the BYTEMAP routing map.

    Backward: the gradient of ``probs.sum()`` w.r.t. the logits must match
    across the two formats with zero tolerance.
    """
    if topk >= num_experts:
        pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")

    # Deterministic, strictly increasing logits; the construction depends on the score function.
    if score_function in ("sigmoid", "sqrtsoftplus"):
        offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
        expert_ramp = (
            torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
        )
        logits = expert_ramp.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
        flat_ramp = (
            torch.arange(
                -num_tokens * num_experts // 2,
                num_tokens * num_experts // 2,
                device="cuda",
                dtype=dtype,
            )
            * 1e-4
        )
        logits = flat_ramp.view(num_tokens, num_experts)

    def run_router(map_format):
        # Each format gets its own leaf tensor so the two .grad fields stay independent.
        leaf = logits.detach().clone().requires_grad_(True)
        probs, routing_map = fused_topk_with_score_function(
            logits=leaf,
            topk=topk,
            use_pre_softmax=False,
            num_groups=None,
            group_topk=None,
            scaling_factor=None,
            score_function=score_function,
            expert_bias=None,
            routing_map_format=map_format,
        )
        return leaf, probs, routing_map

    logits_byte, probs_byte, routing_map_byte = run_router(RoutingMapFormat.BYTEMAP)
    logits_bit, probs_bit, routing_map_bit = run_router(RoutingMapFormat.BITMAP_U8)

    # --- forward: probabilities ---
    assert probs_byte.dtype == probs_bit.dtype
    torch.testing.assert_close(probs_byte, probs_bit, atol=0.0, rtol=0.0)

    # --- forward: routing map layout ---
    expected_shape = (num_tokens, (num_experts + 7) // 8)
    assert (
        routing_map_bit.shape == expected_shape
    ), f"Bitmap shape {tuple(routing_map_bit.shape)} != {expected_shape}"
    assert routing_map_bit.dtype == torch.uint8
    assert routing_map_byte.dtype == torch.bool

    packed_expected = _bytemap_to_bitmap_u8(routing_map_byte)
    torch.testing.assert_close(routing_map_bit, packed_expected, atol=0, rtol=0)

    # --- backward: grad of probs.sum() must be bit-identical across formats ---
    for probs in (probs_byte, probs_bit):
        probs.sum().backward()
    torch.testing.assert_close(logits_byte.grad, logits_bit.grad, atol=0.0, rtol=0.0)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(128, 32, 4), (256, 128, 8), (256, 130, 8)],
)
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_score_for_aux_loss_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function):
    """Check "bitmap_u8" / "bytemap" parity of fused_compute_score_for_moe_aux_loss.

    Forward: scores must match with zero tolerance; the bytemap routing map
    must be ``torch.bool`` and the bitmap routing map must be ``torch.uint8``
    of shape ``(num_tokens, ceil(num_experts / 8))`` and equal to the LSB-first
    packing of the bytemap routing map.

    Backward: the gradient of ``scores.sum()`` w.r.t. the logits must match
    across the two formats with zero tolerance.
    """
    if topk >= num_experts:
        pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")

    # Deterministic logits: a per-expert ramp plus a small per-token offset.
    offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
    expert_ramp = (
        torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
    )
    logits = expert_ramp.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)

    def run_aux_scores(map_format):
        # Each format gets its own leaf tensor so the two .grad fields stay independent.
        leaf = logits.detach().clone().requires_grad_(True)
        routing_map, scores = fused_compute_score_for_moe_aux_loss(
            logits=leaf,
            topk=topk,
            score_function=score_function,
            routing_map_format=map_format,
        )
        return leaf, routing_map, scores

    logits_byte, routing_map_byte, scores_byte = run_aux_scores("bytemap")
    logits_bit, routing_map_bit, scores_bit = run_aux_scores("bitmap_u8")

    torch.testing.assert_close(scores_byte, scores_bit, atol=0.0, rtol=0.0)

    expected_shape = (num_tokens, (num_experts + 7) // 8)
    assert routing_map_bit.shape == expected_shape
    assert routing_map_bit.dtype == torch.uint8
    assert routing_map_byte.dtype == torch.bool
    packed_expected = _bytemap_to_bitmap_u8(routing_map_byte)
    torch.testing.assert_close(routing_map_bit, packed_expected, atol=0, rtol=0)

    # Backward parity through scores.
    for scores in (scores_byte, scores_bit):
        scores.sum().backward()
    torch.testing.assert_close(logits_byte.grad, logits_bit.grad, atol=0.0, rtol=0.0)


def profile_topk_softmax(
dtype,
num_tokens,
Expand Down
Loading
Loading