Skip to content
Merged
142 changes: 115 additions & 27 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import (
_cudnn_frontend_version_supported,
is_glu_activation,
)

from transformer_engine.pytorch.ops.fused import (
Expand Down Expand Up @@ -2480,6 +2481,59 @@ def test_scaled_swiglu(
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)

@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
def test_scaled_srelu(
self,
*,
in_shape: Iterable[int],
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
) -> None:
"""SReLU with post-scale"""

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=input_requires_grad,
)
scales_ref, scales_test = make_reference_and_test_tensors(
in_shape[:-1],
test_dtype=dtype,
test_device=device,
requires_grad=scales_requires_grad,
)
dy_ref, dy_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y = torch.nn.functional.relu(x_ref).square()
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)

# Implementation with fusible operation
op = te_ops.ScaledSReLU()
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
y_test.backward(dy_test)

# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(scales_test, scales_ref, **tols)

def test_interleaved_scaled_swiglu(self):
"""SwiGLU with post-scale and block interleaved input format"""
self.test_scaled_swiglu(
Expand All @@ -2489,6 +2543,15 @@ def test_interleaved_scaled_swiglu(self):
scales_requires_grad=True,
)

@pytest.mark.parametrize(
"op_cls",
(te_ops.ScaledSwiGLU, te_ops.ScaledSReLU, te_ops.ScaledClampedQGeGLU),
)
def test_scaled_activation_recompute_in_mlp_config(self, op_cls) -> None:
"""Scaled activations expose a per-op recompute knob."""
assert op_cls().activation_recompute_in_mlp is False
assert op_cls(activation_recompute_in_mlp=True).activation_recompute_in_mlp is True

@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
Expand Down Expand Up @@ -3570,7 +3633,9 @@ def test_layernorm_mlp(
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("hidden_size", (128, 256))
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
@pytest.mark.parametrize(
"activation", ("scaled_swiglu", "scaled_clamped_qgeglu", "scaled_srelu")
)
def test_grouped_mlp(
self,
*,
Expand All @@ -3588,7 +3653,7 @@ def test_grouped_mlp(
delay_wgrad_compute: bool,
activation: str,
) -> None:
"""GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear"""
"""GroupedLinear + scaled activation + GroupedLinear"""

# Split sizes
split_sizes = [split_alignment * (i) for i in range(group_size)]
Expand All @@ -3601,16 +3666,30 @@ def test_grouped_mlp(

# Skip invalid configurations
with_quantization = quantization is not None
if activation == "scaled_swiglu":
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_clamped_qgeglu":
scaled_act = te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
elif activation == "scaled_srelu":
scaled_act = te_ops.ScaledSReLU()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
activation_is_glu = is_glu_activation(scaled_act)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if single_grouped_weight and quantization != "mxfp8":
pytest.skip("single_grouped_weight is only supported for MXFP8 quantization")
if single_grouped_bias and not bias:
pytest.skip("single_grouped_bias requires bias=True")
if with_quantization and dtype not in (torch.bfloat16, torch.float16):
pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if not activation_is_glu and quantization != "mxfp8":
pytest.skip("Scaled unary grouped MLP fusion is only supported with MXFP8")
if not activation_is_glu and glu_interleave_size is not None:
pytest.skip("Unary activations do not use GLU interleaving")
if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias:
# TODO: ksivaman: Need to debug numerics for this case.
pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU")
fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size

# Random data
x_ref, x_test = make_reference_and_test_tensors(
Expand Down Expand Up @@ -3641,7 +3720,7 @@ def test_grouped_mlp(
fc2_bs_ref, fc2_bs_test = [], []
for _ in range(group_size):
fc1_w_ref, fc1_w_test = make_reference_and_test_tensors(
(2 * hidden_size, hidden_size),
(fc1_out_features, hidden_size),
min=-0.25,
max=0.25,
quantization=quantization,
Expand All @@ -3660,7 +3739,7 @@ def test_grouped_mlp(
fc2_b_ref, fc2_b_test = None, None
if bias:
fc1_b_ref, fc1_b_test = make_reference_and_test_tensors(
(2 * hidden_size,),
(fc1_out_features,),
min=-0.5,
max=0.5,
test_dtype=dtype,
Expand Down Expand Up @@ -3689,7 +3768,7 @@ def test_grouped_mlp(
for group_idx in range(group_size):
x = xs[group_idx]
x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx])
if glu_interleave_size is not None:
if activation_is_glu and glu_interleave_size is not None:
x = x.reshape(
-1,
2 * hidden_size // (2 * glu_interleave_size),
Expand All @@ -3698,15 +3777,20 @@ def test_grouped_mlp(
)
x = x.transpose(1, 2)
x = x.reshape(-1, 2 * hidden_size)
x1, x2 = x.chunk(2, dim=-1)
if activation == "scaled_swiglu":
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
else:
elif activation == "scaled_clamped_qgeglu":
x1, x2 = x.chunk(2, dim=-1)
lim = torch.tensor(7.0, device=x1.device, dtype=x1.dtype)
geglu_alpha = 1.702
x1c = torch.minimum(x1, lim)
x2c = torch.clamp(x2, -lim, lim)
x = (x2c + 1) * (x1c * torch.sigmoid(geglu_alpha * x1c))
elif activation == "scaled_srelu":
x = torch.nn.functional.relu(x).square()
else:
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
x = x * probs[group_idx].unsqueeze(-1)
x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx])
if bias:
Expand All @@ -3717,16 +3801,11 @@ def test_grouped_mlp(

# Construct operations
recipe = make_recipe(quantization)
scaled_act = (
te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
if activation == "scaled_swiglu"
else te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size)
)
with te.quantized_model_init(enabled=with_quantization, recipe=recipe):
fc1 = te_ops.GroupedLinear(
group_size,
hidden_size,
2 * hidden_size,
fc1_out_features,
bias=bias,
device=device,
dtype=dtype,
Expand Down Expand Up @@ -3810,22 +3889,31 @@ def test_grouped_mlp(
if (
quantization == "mxfp8"
and dtype in (torch.bfloat16, torch.float16)
and glu_interleave_size == 32
and (
(not activation_is_glu and glu_interleave_size is None)
or (activation_is_glu and glu_interleave_size == 32)
)
and _cudnn_frontend_version_supported()
):
if te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if activation_is_glu:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8
else:
forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary_MXFP8
backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8
if forward_cls.is_supported():
forward_ops = module._module_groups[0]._forward_ops
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
forward_cls,
)
if te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if backward_cls is not None and backward_cls.is_supported():
backward_ops = module._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
backward_cls,
)

# Loose tols for sanity checking
Expand Down Expand Up @@ -3910,9 +3998,9 @@ def test_grouped_mlp_single_weight_numerics(
) -> None:
"""single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP."""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")

split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
Expand Down Expand Up @@ -4014,12 +4102,12 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]:
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8,
)
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8,
)

if single_grouped_weight:
Expand Down Expand Up @@ -4132,9 +4220,9 @@ def test_grouped_mlp_overwrite_main_grad(
that read ``.grad`` don't see stale bytes from the cached dummy).
"""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system")
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system")

recipe = make_recipe("mxfp8")
Expand Down Expand Up @@ -4266,7 +4354,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8(
) -> None:
"""Grouped MLP forward+backward should be CUDA graph capturable (MXFP8)."""

if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported():
pytest.skip("MXFP8 fused grouped MLP is not supported on this system")
if dtype not in (torch.bfloat16, torch.float16):
pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16")
Expand Down Expand Up @@ -4408,12 +4496,12 @@ def train_step(
assert len(forward_ops) == 1
assert isinstance(
forward_ops[0][0],
te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8,
)
assert len(backward_ops) == 1
assert isinstance(
backward_ops[0][0],
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8,
)

fresh_x = torch.randn_like(static_x)
Expand Down
Loading
Loading