diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index 42df06ed7f..dff5114125 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -148,12 +148,6 @@ def test_fused_adam_fp8_master_weights(recipe_name): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "NVFP4BlockScaling": - pytest.xfail( - f"{recipe_name}: quantized_model_init and FSDP2 is not currently supported, since the " - "block tensor is dequantized before we flatten it for FSDP2." - ) - world_size, device = _get_dist_info() model = _build_model(fp8_init=True, recipe=recipe) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py index fce565ed9a..c84b91453a 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py @@ -224,12 +224,16 @@ def _check_fp8_fsdp2_allgather(model): if device_mesh.ndim > 1 else device_mesh.get_group() ) - # Perform manual allgather on local_tensor. zeros_like will create hp tensor since torch_dispatch - # for local_tensor will go down the dequantization route. + # Dequantize first, then create plain-tensor buffers for the manual + # all-gather. Using zeros_like(local_tensor) directly would return a + # QuantizedTensor for types like NVFP4Tensor (whose __torch_dispatch__ + # handles empty_like/zero_ and returns a new NVFP4Tensor), causing a + # dtype mismatch with the bfloat16 source. 
+ deq_local = local_tensor.dequantize() gathered_tensor = [ - torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group)) + torch.zeros_like(deq_local) for _ in range(dist.get_world_size(group=dist_group)) ] - dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group) + dist.all_gather(gathered_tensor, deq_local, group=dist_group) full_tensor = torch.cat(gathered_tensor, dim=0) fp32_allgathered_params[name] = full_tensor # FP8 allgather using FSDP2 @@ -363,8 +367,19 @@ def _train(args): @pytest.mark.parametrize("fp8_init", [False, True]) @pytest.mark.parametrize("layer_type", ["LayerNormLinear", "TransformerLayer"]) def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type): - if recipe_name in ("Float8BlockScaling", "NVFP4BlockScaling") and fp8_init: - pytest.xfail(f"{recipe_name} + fp8_init: test_fp8_fsdp2_allgather is currently failing.") + if recipe_name == "Float8BlockScaling" and fp8_init: + pytest.xfail( + "Float8BlockScaling + fp8_init: scale inverse padding is not handled " + "correctly during FSDP2 all-gather slice ops." + ) + if recipe_name == "NVFP4BlockScaling" and fp8_init: + pytest.xfail( + "NVFP4BlockScaling + fp8_init: _check_fp8_fsdp2_allgather numerical " + "comparison fails — the FSDP2 allgather path (pack → unpad scales → " + "allgather → repad → dequantize) produces small differences vs the " + "manual dequantize-then-allgather path (max abs diff ~1.7e-4). " + "NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py." + ) torch.manual_seed(42) torch.cuda.manual_seed(42) diff --git a/tests/pytorch/test_nvfp4_fsdp2_hooks.py b/tests/pytorch/test_nvfp4_fsdp2_hooks.py new file mode 100644 index 0000000000..669d94fc25 --- /dev/null +++ b/tests/pytorch/test_nvfp4_fsdp2_hooks.py @@ -0,0 +1,288 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for NVFP4Tensor FSDP2 all-gather hooks. 
+ +These tests verify the pre/post all-gather round-trip logic on a single GPU +without requiring torchrun or multi-GPU setup. +""" + +import math +from typing import List, Tuple + +import pytest +import torch + +import transformer_engine.pytorch as te +from transformer_engine.pytorch import ( + NVFP4Quantizer, + NVFP4Tensor, +) +from transformer_engine.pytorch.utils import round_up_to_nearest_multiple +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE + +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +# Shapes that exercise various M/K combinations: +# - (512, 256): both dims cleanly divisible by 128 +# - (640, 128): M not a multiple of 128*2 but divisible by 16 +# - (256, 1024): K > M +_test_shapes: List[Tuple[int, int]] = [ + (512, 256), + (640, 128), + (256, 1024), +] + + +def _make_nvfp4_tensor(shape: Tuple[int, int]) -> NVFP4Tensor: + """Create an NVFP4Tensor from random BF16 data.""" + quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + src = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + return quantizer(src) + + +def _simulate_all_gather( + sharded_tensors: Tuple[torch.Tensor, ...], + world_size: int, +) -> Tuple[torch.Tensor, ...]: + """Simulate FSDP2 all-gather by concatenating shards along dim0.""" + return tuple(torch.cat([t] * world_size, dim=0) for t in sharded_tensors) + + +@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4) +class TestNVFP4FSDP2Hooks: + """Tests for fsdp_pre_all_gather / fsdp_post_all_gather round-trip.""" + + @classmethod + def setup_class(cls) -> None: + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + @pytest.mark.parametrize("shape", _test_shapes) + @pytest.mark.parametrize("world_size", [2, 4]) + def test_round_trip_shapes(self, shape: Tuple[int, int], world_size: int): + """Verify 
that pre_all_gather -> all_gather -> post_all_gather produces correct shapes.""" + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + + # Pre all-gather + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, + ) + + # Only rowwise tensors are all-gathered; columnwise is derived locally + assert len(sharded_tensors) == 2, "Expected 2 tensors (rowwise data + scale only)" + + rowwise_data, rowwise_scale_inv = sharded_tensors + + # Rowwise data: (shard_M, K//2) — unmodified + assert rowwise_data.shape == (shard_M, K // 2) + # Rowwise scale: unpadded dim0 to shard_M + assert rowwise_scale_inv.shape[0] == shard_M + + # Simulate all-gather + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + + # Post all-gather + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, + ) + + # Verify output is NVFP4Tensor with correct logical shape + assert isinstance(result, NVFP4Tensor) + assert tuple(result.shape) == (M, K) + + # Verify internal data shapes + assert result._rowwise_data.shape == (M, K // 2) + + expected_rowwise_scale_shape = ( + round_up_to_nearest_multiple(M, 128), + round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4), + ) + assert result._rowwise_scale_inv.shape == expected_rowwise_scale_shape + + # Columnwise data derived locally via _create_columnwise() + assert result._columnwise_data.shape == (K, M // 2) + + expected_col_scale_shape = ( + round_up_to_nearest_multiple(K, 128), + round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4), + ) + assert result._columnwise_scale_inv.shape == expected_col_scale_shape + + @pytest.mark.parametrize("shape", _test_shapes) + def test_round_trip_data_integrity(self, shape: Tuple[int, int]): + """Verify that data and dequantized values survive the pre -> all_gather -> 
post round-trip.""" + world_size = 2 + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + + # Save original internal tensors for comparison + orig_rowwise_data = qt._rowwise_data.clone() + orig_rowwise_scale = qt._rowwise_scale_inv.clone() + orig_amax_row = qt._amax_rowwise.clone() + orig_amax_col = qt._amax_columnwise.clone() + orig_deq = qt.dequantize() + + # Pre all-gather + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, + ) + + # Simulate all-gather (world_size copies — data from each "rank" is identical) + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + + # Post all-gather + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, + ) + + # Since each "rank" has the same data, the full rowwise_data should be + # the original shard repeated world_size times + expected_rowwise_data = torch.cat([orig_rowwise_data] * world_size, dim=0) + assert torch.equal(result._rowwise_data, expected_rowwise_data) + + # Rowwise scale: each shard's unpadded scale is repeated, then repadded + # Check that the first shard_M rows of the scale match the original (unpadded) + assert torch.equal( + result._rowwise_scale_inv[:shard_M, :], + orig_rowwise_scale[:shard_M, :], + ) + + # Columnwise data is derived locally via _create_columnwise(), not all-gathered. + # Verify it was created and has the correct shape. 
+ assert result._columnwise_data is not None + assert result._columnwise_data.shape == (K, M // 2) + assert result._columnwise_scale_inv is not None + + # Amax values passed through metadata — should be preserved + assert torch.equal(result._amax_rowwise, orig_amax_row) + assert torch.equal(result._amax_columnwise, orig_amax_col) + + # Dequantized values: the full tensor should dequantize to world_size copies of the shard + result_deq = result.dequantize() + expected_deq = torch.cat([orig_deq] * world_size, dim=0) + torch.testing.assert_close(result_deq, expected_deq) + + @pytest.mark.parametrize("shape", _test_shapes) + def test_in_place_update(self, shape: Tuple[int, int]): + """Verify the out= path (in-place update on subsequent iterations).""" + world_size = 2 + M, K = shape + shard_M = M // world_size + shard_shape = (shard_M, K) + + qt = _make_nvfp4_tensor(shard_shape) + + sharded_tensors, metadata = qt.fsdp_pre_all_gather( + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, + ) + all_gather_outputs = _simulate_all_gather(sharded_tensors, world_size) + + # First call: out=None -> creates new tensor + result, _ = qt.fsdp_post_all_gather( + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, + ) + first_deq = result.dequantize().clone() + + # Second call: out=result -> in-place update + result2, _ = qt.fsdp_post_all_gather( + all_gather_outputs, + metadata, + param_dtype=torch.bfloat16, + out=result, + ) + assert result2 is result # same object + torch.testing.assert_close(result2.dequantize(), first_deq) + + def test_swizzled_scales_rejected(self): + """Verify that GEMM-swizzled scales raise NotImplementedError.""" + shape = (512, 256) + quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + quantizer.optimize_for_gemm = True + src = torch.randn(shape, 
dtype=torch.bfloat16, device="cuda") + qt = quantizer(src) + + if not qt._with_gemm_swizzled_scales: + pytest.skip( + "NVFP4Quantizer.optimize_for_gemm is not yet wired up in C++ " + "(see quantizer.cpp TODO). Test will be unskipped once supported." + ) + + with pytest.raises(NotImplementedError, match="GEMM-swizzled"): + qt.fsdp_pre_all_gather( + mesh=None, + orig_size=None, + contiguous_orig_stride=None, + module=None, + mp_policy=None, + ) + + +@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4) +class TestNVFP4DispatchHandlers: + """Tests for as_strided, slice, and record_stream dispatch handlers.""" + + def test_as_strided_noop(self): + """as_strided with matching shape/strides returns NVFP4Tensor.""" + qt = _make_nvfp4_tensor((256, 128)) + M, K = qt.shape + result = torch.ops.aten.as_strided.default(qt, [M, K], [K, 1], 0) + assert isinstance(result, NVFP4Tensor) + assert tuple(result.shape) == (M, K) + + def test_slice_noop(self): + """slice covering full dimension returns NVFP4Tensor.""" + qt = _make_nvfp4_tensor((256, 128)) + M, K = qt.shape + result = torch.ops.aten.slice.Tensor(qt, 0, 0, M) + assert isinstance(result, NVFP4Tensor) + assert tuple(result.shape) == (M, K) + + def test_record_stream(self): + """record_stream completes without error.""" + qt = _make_nvfp4_tensor((256, 128)) + stream = torch.cuda.Stream() + result = torch.ops.aten.record_stream.default(qt, stream) + assert result is None diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index eb514d3a9e..91f9c81cb1 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -551,6 +551,122 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self._columnwise_data is not None, } + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Called by FSDP2 before all-gather of weights. 
+ + Only all-gathers rowwise data and scales. Columnwise data is derived + locally in post_all_gather via _create_columnwise(), halving the + all-gather communication volume. + """ + # pylint: disable=unused-argument + + if self._with_gemm_swizzled_scales: + raise NotImplementedError( + "FSDP2 is not supported for NVFP4Tensors with GEMM-swizzled scales." + ) + + shard_M = math.prod(self.shape[:-1]) + + assert shard_M % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"FSDP2 requires shard_M ({shard_M}) to be a multiple of " + f"NVFP4_BLOCK_SCALING_SIZE ({NVFP4_BLOCK_SCALING_SIZE}). " + "Adjust model dimensions or world size." + ) + + assert self._rowwise_data is not None, ( + "FSDP2 requires rowwise data, but _rowwise_data is None. " + "Ensure the NVFP4Quantizer was created with rowwise=True." + ) + + # Rowwise data: (shard_M, K//2) — M in dim0, pass as-is + rowwise_data = self._rowwise_data + # Rowwise scale: (round_up(shard_M, 128), inner) — unpad dim0 to shard_M + rowwise_scale_inv = self._rowwise_scale_inv + if rowwise_scale_inv is not None: + rowwise_scale_inv = rowwise_scale_inv[:shard_M, :] + + columnwise_usage = self._quantizer.columnwise_usage + if columnwise_usage: + assert self._quantizer.with_2d_quantization, ( + "FSDP2 columnwise usage requires 2D quantization to be enabled. " + "Ensure the NVFP4Quantizer was created with with_2d_quantization=True." + ) + + # Only all-gather rowwise tensors; columnwise will be derived locally + # via _create_columnwise() in post_all_gather. 
+ sharded_tensors = (rowwise_data, rowwise_scale_inv) + + # Pass amax via metadata (scalar, same on all ranks — not all-gathered) + metadata = ( + self._fp4_dtype, + columnwise_usage, + self._amax_rowwise, + self._amax_columnwise, + self.shape[-1], + ) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata, + param_dtype: torch.dtype, + *, + out: Optional[NVFP4Tensor] = None, + ): + """Called by FSDP2 after all-gather of weights. + + Repads rowwise scales and constructs the full NVFP4Tensor from + all-gathered rowwise data. Columnwise data is derived locally + via _create_columnwise() instead of being all-gathered. + """ + fp4_dtype, columnwise_usage, amax_rowwise, amax_columnwise, K = metadata + + # Only rowwise data+scales were all-gathered + rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] + full_M = rowwise_data.shape[0] + + # Repad rowwise scale dim0 to round_up(full_M, 128) + if rowwise_scale_inv is not None: + target_m = round_up_to_nearest_multiple(full_M, 128) + current_m = rowwise_scale_inv.shape[0] + if current_m < target_m: + rowwise_scale_inv = torch.nn.functional.pad( + rowwise_scale_inv, (0, 0, 0, target_m - current_m) + ) + + logical_shape = (full_M, K) + + if out is not None: + # Update existing tensor in-place (subsequent iterations) + out._rowwise_data = rowwise_data + out._rowwise_scale_inv = rowwise_scale_inv + out._amax_rowwise = amax_rowwise + out._amax_columnwise = amax_columnwise + else: + # Construct new tensor (first iteration) + out = NVFP4Tensor( + shape=logical_shape, + dtype=param_dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=None, + columnwise_scale_inv=None, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=self._quantizer, + requires_grad=False, + with_gemm_swizzled_scales=False, + ) + + # Derive columnwise data locally via transpose instead of 
all-gathering it + if columnwise_usage: + out._create_columnwise() + + out._quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + return out, all_gather_outputs + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -564,6 +680,58 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return tensor.detach() return tensor.view(shape) + # as_strided — FSDP2 applies this on the unsharded param. + # Only the identity case (same shape, contiguous strides, zero offset) is supported. + # Non-identity as_strided cannot fall through because NVFP4 does not support + # dequantization, so we raise explicitly rather than producing undefined behavior. + if func == aten.as_strided.default: + tensor = args[0] + shape = args[1] + strides = args[2] + storage_offset = args[3] if len(args) > 3 else 0 + if ( + len(shape) == len(strides) == 2 + and tuple(strides) == (shape[-1], 1) + and tuple(shape) == tuple(tensor.size()) + and storage_offset == 0 + ): + return NVFP4Tensor.make_like(tensor) + raise NotImplementedError( + "NVFP4Tensor does not support non-identity as_strided " + f"(shape={shape}, strides={strides}, storage_offset={storage_offset}, " + f"tensor.size()={tuple(tensor.size())})" + ) + + # slice — FSDP2 applies this for shard unpadding. + # A full-dimension slice is a no-op: return NVFP4Tensor.make_like(tensor), not self. + if func == aten.slice.Tensor: + tensor = args[0] + dim = args[1] if len(args) > 1 else 0 + start = args[2] if len(args) > 2 else None + end = args[3] if len(args) > 3 else None + step = args[4] if len(args) > 4 else 1 + if ( + step == 1 + and (start is None or start == 0) + and (end is None or end >= tensor.size(dim)) + ): + return NVFP4Tensor.make_like(tensor) + + # record_stream — FSDP2 records streams on all-gathered tensors. 
+ if func == torch.ops.aten.record_stream.default: + qt, stream = args + for t in ( + qt._rowwise_data, + qt._columnwise_data, + qt._rowwise_scale_inv, + qt._columnwise_scale_inv, + qt._amax_rowwise, + qt._amax_columnwise, + ): + if t is not None and t.is_cuda: + t.record_stream(stream) + return None + # NVFP4 dequantize not supported. Add manual support for needed funcs. if func in (aten.empty_like.default, aten.zero_.default): tensor = args[0]