From 82ea46b68c2fa160e8b26a4cb9094d0db8c99d4e Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Sun, 17 May 2026 21:58:54 -0700 Subject: [PATCH 1/8] Generalized Tensor Parallelism (GTP) init commit Co-authored-by: Jieming Zhang Signed-off-by: Shiqing Fan --- tests/pytorch/distributed/test_gtp.py | 1335 +++++++++++++ tests/pytorch/distributed/test_tp_gtp.py | 427 +++++ transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/recipe.h | 20 + .../common/recipe/multi_amax.cu | 274 +++ transformer_engine/pytorch/csrc/common.h | 21 +- transformer_engine/pytorch/csrc/extensions.h | 15 + .../pytorch/csrc/extensions/cast.cpp | 142 ++ .../pytorch/csrc/extensions/pybind.cpp | 10 + transformer_engine/pytorch/csrc/quantizer.cpp | 52 +- transformer_engine/pytorch/distributed.py | 138 +- transformer_engine/pytorch/module/base.py | 44 +- .../module/generalized_tensor_parallelism.py | 1692 +++++++++++++++++ .../pytorch/module/grouped_linear.py | 143 +- .../pytorch/module/layernorm_linear.py | 274 ++- transformer_engine/pytorch/module/linear.py | 138 +- 16 files changed, 4549 insertions(+), 177 deletions(-) create mode 100644 tests/pytorch/distributed/test_gtp.py create mode 100644 tests/pytorch/distributed/test_tp_gtp.py create mode 100644 transformer_engine/common/recipe/multi_amax.cu create mode 100644 transformer_engine/pytorch/module/generalized_tensor_parallelism.py diff --git a/tests/pytorch/distributed/test_gtp.py b/tests/pytorch/distributed/test_gtp.py new file mode 100644 index 0000000000..972af13762 --- /dev/null +++ b/tests/pytorch/distributed/test_gtp.py @@ -0,0 +1,1335 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for Generalized Tensor Parallelism (GTP). + +Test groups +----------- +1. TestGTPWeightState – state-machine transitions (single-process) +2. TestGTPWeightCache – coat-check buffer pool (single-process) +3. TestGTPSharding – wrap_module_params_gtp: shard content + padding (multi-GPU) +4. TestWrapModuleParams – wrap_module_params_gtp: param replacement + weight_list (multi-GPU) +5. TestLinearGTP – Linear forward/backward numerical correctness (multi-GPU) +6. TestLayerNormLinearGTP – LayerNormLinear forward/backward smoke test (multi-GPU) +7. TestGroupedLinearGTP – GroupedLinear forward/backward smoke test (multi-GPU) +8. TestGTPPrefetchChain – linked-list next_w/prev_w wiring (multi-GPU) +9. TestGTPWgradRS – wgrad reduce-scatter shape + multi-layer deferred path (multi-GPU) +10. TestGTPMicrobatches – output consistency across microbatches (multi-GPU) +11. TestNVFP4LinearGTP – Linear + NVFP4 recipe: quantized shard setup, fwd/bwd (multi-GPU) +12. TestNVFP4GroupedLinearGTP – GroupedLinear + NVFP4 recipe: coalesced AG + fwd/bwd (multi-GPU) +13. TestMXFP8LinearGTP – Linear + MXFP8 recipe: quantized shard setup, fwd/bwd, padding (multi-GPU) +14. TestGTPConfig – update_config: valid/invalid keys (single-process) +15. TestGTPShardedParamProperties – shape computations, get_padded_shard, _strip_padding (single-process) +16. TestGTPCacheKey – _get_cache_key: expert vs non-expert, fwd vs bwd (single-process) +17. TestGTPCacheRelease – reserve/get/release pool semantics (single-process) +18. TestTagGTPParamsWithNames – _debug_name population on GTPShardedParam (single-process) +19. TestGTPGroupSizeOne – wrap_module_params_gtp no-op when gtp_group.size()==1 (single-process) +20. TestGTPPrefetchDisabled – weight_prefetch=False: single-pass forward still works (multi-GPU) +21. TestFuseWgradAccumulation – fuse_wgrad_accumulation=True: wgrad→main_grad (multi-GPU) +22. TestGTPGradAccumHook – main_grad updated after reduce-scatter backward (multi-GPU) + +Multi-GPU tests use torch.multiprocessing.spawn and are skipped when fewer +than the required CUDA devices are available. +""" + +import os +import socket + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.module.generalized_tensor_parallelism as gtp_module +from transformer_engine.pytorch.module.generalized_tensor_parallelism import ( + GTPShardedParam, + GTPWeightCache, + GTPWeightState, + wrap_module_params_gtp, +) +from transformer_engine.pytorch import fp8_autocast, is_nvfp4_available, is_mxfp8_available +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor +from transformer_engine.common.recipe import NVFP4BlockScaling + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def reset_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +@pytest.fixture(autouse=True) +def reset_gtp_globals(): + """Reset all GTP mutable class/module-level state between tests.""" + yield + GTPShardedParam._first_weight_flag = True + GTPShardedParam._pending_rs_weight = None + GTPShardedParam._chain_state = {} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _dist_init(rank: int, world_size: int, port: int) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def _run_distributed(fn, world_size: int, *args) -> None: + """Spawn `world_size` processes each running fn(rank, world_size, port, *args).""" + port = _free_port() + mp.spawn(fn, args=(world_size, port) + args, nprocs=world_size, join=True) + + +def _requires_multi_gpu(n: int = 4): + if torch.cuda.device_count() < n: + pytest.skip(f"Requires at least {n} CUDA devices") + + +def _requires_nvfp4(): + if not is_nvfp4_available(): + pytest.skip("NVFP4 not available (requires compute capability >= 10.0)") + + +# --------------------------------------------------------------------------- +# 1. GTPWeightState – state-machine transition tests +# --------------------------------------------------------------------------- + +class TestGTPWeightState: + + @staticmethod + def _param(): + return GTPShardedParam(torch.zeros(4, 4)) + + def test_full_cycle(self): + p = self._param() + assert p.state == GTPWeightState.NONE + p._set_state(GTPWeightState.ASYNC_WAIT) + p._set_state(GTPWeightState.DATA_READY) + p._set_state(GTPWeightState.NONE) + assert p.state == GTPWeightState.NONE + + def test_sync_path_cycle(self): + """NONE → DATA_READY_SYNC → NONE (sync all-gather path).""" + p = self._param() + p._set_state(GTPWeightState.DATA_READY_SYNC) + p._set_state(GTPWeightState.NONE) + assert p.state == GTPWeightState.NONE + + def test_rs_state_full_cycle(self): + """RS state machine: NONE → ASYNC_WAIT → DATA_READY → NONE.""" + p = self._param() + assert p.rs_state == GTPWeightState.NONE + p._set_rs_state(GTPWeightState.ASYNC_WAIT) + p._set_rs_state(GTPWeightState.DATA_READY) + p._set_rs_state(GTPWeightState.NONE) + assert p.rs_state == GTPWeightState.NONE + + +# --------------------------------------------------------------------------- +# 2. GTPWeightCache – coat-check buffer pool tests +# --------------------------------------------------------------------------- + +class TestGTPWeightCache: + + class _FakeGroup: + def __init__(self, size=2): + self._size = size + def size(self): + return self._size + def rank(self): + return 0 + + def _param(self, shape=(8, 4), gtp_size=2): + p = GTPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup(gtp_size) + p.expert_idx = None + p.pad_length = 0 + p._quantizer = None + return p + + def test_reserve_returns_ticket(self): + cache = GTPWeightCache() + p = self._param() + ticket = cache.reserve(p, torch.bfloat16, fwd=True) + assert isinstance(ticket, int) + + def test_reserve_get_roundtrip(self): + cache = GTPWeightCache() + p = self._param() + ticket = cache.reserve(p, torch.bfloat16, fwd=True) + buf = cache.get(ticket) + assert buf is not None + # get() returns same buf on second call (buf cached in slot) + buf2 = cache.get(ticket) + assert buf2 is buf + + def test_buffer_reused_after_release(self): + cache = GTPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + cache.release(t1) + # Reserve a new ticket, buf should come from pool + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf1 is buf2, "Buffer should be reused from pool after release" + cache.release(t2) + + def test_two_simultaneous_reserves_are_distinct(self): + cache = GTPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf1 is not buf2, "Concurrent reserves must get distinct buffers" + + def test_tickets_are_unique(self): + """Each reserve() call returns a new unique ticket.""" + cache = GTPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + assert t1 != t2, "Each reserve() must return a unique ticket" + + def test_invalid_ticket_raises(self): + cache = GTPWeightCache() + with pytest.raises(KeyError): + cache.get(9999) + + def test_different_shapes_use_distinct_pool_slots(self): + cache = GTPWeightCache() + p1 = self._param(shape=(8, 4)) + p2 = self._param(shape=(16, 4)) + t1 = cache.reserve(p1, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + t2 = cache.reserve(p2, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf1.shape != buf2.shape + cache.release(t1); cache.release(t2) + + def test_fwd_bwd_tickets_are_distinct(self): + """fwd=True and fwd=False reserves always receive distinct ticket IDs.""" + cache = GTPWeightCache() + p = self._param() + t_fwd = cache.reserve(p, torch.bfloat16, fwd=True) + t_bwd = cache.reserve(p, torch.bfloat16, fwd=False) + assert t_fwd != t_bwd + + +# --------------------------------------------------------------------------- +# 3. GTP weight sharding: shard content and alignment padding +# --------------------------------------------------------------------------- + +def _worker_sharding_aligned(rank, world_size, port): + _dist_init(rank, world_size, port) + K, M = world_size * 32, 16 # K divisible by 16*world_size → no padding + full_weight = torch.arange(K * M, dtype=torch.float32).reshape(K, M).cuda() + dist.broadcast(full_weight, src=0) + + gtp_group = dist.new_group(list(range(world_size))) + mod = nn.Module() + mod.weight = nn.Parameter(full_weight.clone(), requires_grad=False) + wrap_module_params_gtp(mod, ['weight'], gtp_group) + shard = mod.weight + + rows_per_rank = K // world_size + assert shard.shape == (rows_per_rank, M), f"rank {rank}: unexpected shape {shard.shape}" + assert shard.pad_length == 0 + expected = full_weight[rank * rows_per_rank : (rank + 1) * rows_per_rank] + assert torch.allclose(shard.data, expected), f"rank {rank}: shard content mismatch" + dist.destroy_process_group() + + +def _worker_sharding_padding(rank, world_size, port): + _dist_init(rank, world_size, port) + alignment = 16 * world_size + K = alignment - 1 # deliberately unaligned + M = 16 + full_weight = torch.ones(K, M, dtype=torch.float32).cuda() + dist.broadcast(full_weight, src=0) + + gtp_group = dist.new_group(list(range(world_size))) + mod = nn.Module() + mod.weight = nn.Parameter(full_weight.clone(), requires_grad=False) + wrap_module_params_gtp(mod, ['weight'], gtp_group) + shard = mod.weight + + padded_K = alignment + rows_per_rank = padded_K // world_size + + if rank == world_size - 1: + assert shard.pad_length > 0 + # The shard tensor holds only the real rows; get_padded_shard() appends zero rows. + padded = shard.get_padded_shard() + assert padded.shape[0] == rows_per_rank, \ + f"rank {rank}: expected padded shard {rows_per_rank} rows, got {padded.shape[0]}" + n_real = K - rank * rows_per_rank + assert torch.all(padded[n_real:] == 0), "Padding rows must be zero" + else: + # pad_length is set globally on every rank's shard (slicer attaches the + # global padding amount), so we don't assert anything about it here — + # only the last rank's shard contains the actual padding rows. + assert shard.shape[0] == rows_per_rank, \ + f"rank {rank}: expected {rows_per_rank} rows, got {shard.shape[0]}" + + dist.destroy_process_group() + + +class TestGTPSharding: + def test_aligned_shard_content(self): + _requires_multi_gpu(4) + _run_distributed(_worker_sharding_aligned, 4) + + def test_unaligned_shard_padding(self): + _requires_multi_gpu(4) + _run_distributed(_worker_sharding_padding, 4) + + +# --------------------------------------------------------------------------- +# 4. wrap_module_params_gtp: param replacement and GroupedLinear weight_list +# --------------------------------------------------------------------------- + +def _worker_linear_param_replaced(rank, world_size, port): + _dist_init(rank, world_size, port) + in_f, out_f = 64, 128 + gtp_group = dist.new_group(list(range(world_size))) + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=torch.bfloat16, + device="cuda", gtp_group=gtp_group, + ) + w = layer.weight + assert isinstance(w, GTPShardedParam), "weight must be GTPShardedParam" + assert w.shape == (out_f // world_size, in_f), f"unexpected shard shape {w.shape}" + assert w.group is gtp_group + dist.destroy_process_group() + + +def _worker_grouped_weight_list(rank, world_size, port): + _dist_init(rank, world_size, port) + num_gemms, in_f, out_f = 3, 32, 64 + gtp_group = dist.new_group(list(range(world_size))) + layer = te.GroupedLinear( + num_gemms=num_gemms, in_features=in_f, out_features=out_f, + bias=False, params_dtype=torch.bfloat16, + device="cuda", gtp_group=gtp_group, + ) + w0 = layer.weight0 + assert isinstance(w0, GTPShardedParam) + assert w0.weight_list is not None + assert len(w0.weight_list) == num_gemms + assert [w.expert_idx for w in w0.weight_list] == list(range(num_gemms)) + dist.destroy_process_group() + + +class TestWrapModuleParams: + def test_linear_weight_replaced(self): + _requires_multi_gpu(4) + _run_distributed(_worker_linear_param_replaced, 4) + + def test_grouped_linear_weight_list(self): + _requires_multi_gpu(4) + _run_distributed(_worker_grouped_weight_list, 4) + + +# --------------------------------------------------------------------------- +# 5. Linear forward/backward numerical correctness +# --------------------------------------------------------------------------- + +def _worker_linear_correctness(rank, world_size, port): + """GTP output == (all-gathered weight) @ input, and dX matches.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + batch, in_f, out_f = 16, 64, 128 # out_f % (16*world_size)==0 → no padding + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + + # Reconstruct full weight from shards (all-gather) + shard = layer.weight.data.clone() + all_shards = [torch.zeros_like(shard) for _ in range(world_size)] + dist.all_gather(all_shards, shard, group=gtp_group) + full_weight = torch.cat(all_shards, dim=0).float()[:out_f] # strip any padding + + # Shared input across ranks + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + inp_gtp = inp.clone().requires_grad_(True) + inp_ref = inp.clone().requires_grad_(True) + + # GTP forward + out_gtp = layer(inp_gtp, is_first_microbatch=True) + + # Reference forward + out_ref = inp_ref.float() @ full_weight.T + out_ref = out_ref.to(dtype) + + assert out_gtp.shape == out_ref.shape, f"Shape mismatch {out_gtp.shape} vs {out_ref.shape}" + assert torch.allclose(out_gtp.float(), out_ref.float(), atol=0.1, rtol=0.1), ( + f"Output mismatch max_diff={(out_gtp.float()-out_ref.float()).abs().max():.4f}" + ) + + # wgrad RS path always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + + # Backward: compare input gradient + grad_out = torch.randn_like(out_gtp) + dist.broadcast(grad_out, src=0) + out_gtp.backward(grad_out) + out_ref.backward(grad_out.float()) + + assert inp_gtp.grad is not None + assert torch.allclose(inp_gtp.grad.float(), inp_ref.grad.float(), atol=0.1, rtol=0.1), ( + f"dX mismatch max_diff={(inp_gtp.grad.float()-inp_ref.grad.float()).abs().max():.4f}" + ) + dist.destroy_process_group() + + +class TestLinearGTP: + def test_forward_backward_correctness(self): + _requires_multi_gpu(4) + _run_distributed(_worker_linear_correctness, 4) + + +# --------------------------------------------------------------------------- +# 6. LayerNormLinear forward/backward smoke test +# --------------------------------------------------------------------------- + +def _worker_layernorm_linear(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + seq, batch, in_f, out_f = 4, 2, 64, 128 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.LayerNormLinear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + assert isinstance(layer.weight, GTPShardedParam) + + inp = torch.randn(seq, batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + out = layer(inp, is_first_microbatch=True) + assert out.shape == (seq, batch, out_f), f"unexpected output shape {out.shape}" + + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + dist.destroy_process_group() + + +class TestLayerNormLinearGTP: + def test_forward_backward(self): + _requires_multi_gpu(4) + _run_distributed(_worker_layernorm_linear, 4) + + +# --------------------------------------------------------------------------- +# 7. GroupedLinear forward/backward smoke test +# --------------------------------------------------------------------------- + +def _worker_grouped_linear(rank, world_size, port, num_gemms): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f, total_tokens = 32, 64, num_gemms * 4 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.GroupedLinear( + num_gemms=num_gemms, in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + assert isinstance(layer.weight0, GTPShardedParam) + + m_splits = [total_tokens // num_gemms] * num_gemms + m_splits[-1] += total_tokens - sum(m_splits) + + inp = torch.randn(total_tokens, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + out = layer(inp, m_splits=m_splits, is_first_microbatch=True) + assert out.shape == (total_tokens, out_f), f"unexpected output shape {out.shape}" + + for i in range(num_gemms): + w = getattr(layer, f"weight{i}") + w.main_grad = torch.zeros(w.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + dist.destroy_process_group() + + +class TestGroupedLinearGTP: + @pytest.mark.parametrize("num_gemms", [2, 4]) + def test_forward_backward(self, num_gemms): + _requires_multi_gpu(4) + _run_distributed(_worker_grouped_linear, 4, num_gemms) + + +# --------------------------------------------------------------------------- +# 8. Prefetch chain: next_w / prev_w wiring after first forward pass +# --------------------------------------------------------------------------- + +def _worker_chain_wired(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + + inp = torch.randn(4, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # First forward pass builds the linked list + l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + + w0, w1 = l0.weight, l1.weight + assert w0.next_w is w1, "w0.next_w should point to w1" + assert w1.prev_w is w0, "w1.prev_w should point back to w0" + assert w1.next_w is None + assert w0.prev_w is None + dist.destroy_process_group() + + +def _worker_chain_async_prefetch(rank, world_size, port): + """On the second forward pass, w1 should be in DATA_READY before its forward runs.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + + inp = torch.randn(4, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # First pass builds chain, second pass uses async prefetch + for _ in range(2): + out = l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + assert torch.isfinite(out).all(), "Non-finite output on second pass" + dist.destroy_process_group() + + +class TestGTPPrefetchChain: + def test_chain_wired_after_first_pass(self): + _requires_multi_gpu(4) + _run_distributed(_worker_chain_wired, 4) + + def test_async_prefetch_second_pass(self): + _requires_multi_gpu(4) + _run_distributed(_worker_chain_async_prefetch, 4) + + +# --------------------------------------------------------------------------- +# 9. Wgrad reduce-scatter: shape and deferred async path +# --------------------------------------------------------------------------- + +def _worker_wgrad_shape(rank, world_size, port): + """After backward, weight.grad shape must match the local shard shape.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + fuse_wgrad_accumulation=False, + ) + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + layer(inp, is_first_microbatch=True).sum().backward() + + w = layer.weight + if w.grad is not None: + assert w.grad.shape == w.shape, \ + f"wgrad shape {w.grad.shape} != shard shape {w.shape}" + dist.destroy_process_group() + + +def _worker_multilayer_deferred_rs(rank, world_size, port): + """Two-layer GTP: async RS deferred for layer0 (non-last), sync for layer1 (last in bwd).""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + # wgrad RS path always accumulates into main_grad; allocate before backward. + l0.weight.main_grad = torch.zeros(l0.weight.shape, dtype=dtype, device="cuda") + l1.weight.main_grad = torch.zeros(l1.weight.shape, dtype=dtype, device="cuda") + + out = l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + out.sum().backward() + + # Both weights' main_grad should have been updated + for lyr in [l0, l1]: + w = lyr.weight + assert w.main_grad is not None, f"No main_grad on {lyr.__class__.__name__}.weight" + dist.destroy_process_group() + + +class TestGTPWgradRS: + def test_wgrad_shape_matches_shard(self): + _requires_multi_gpu(4) + _run_distributed(_worker_wgrad_shape, 4) + + def test_multilayer_deferred_rs(self): + _requires_multi_gpu(4) + _run_distributed(_worker_multilayer_deferred_rs, 4) + + +# --------------------------------------------------------------------------- +# 10. Multiple microbatches: output must be consistent when weight unchanged +# --------------------------------------------------------------------------- + +def _worker_microbatches(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + batch, in_f, out_f = 8, 64, 128 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # First microbatch + out1 = layer(inp, is_first_microbatch=True).detach().clone() + + # Second microbatch with same weight (skip_weight_cast=True path) + out2 = layer(inp, is_first_microbatch=False).detach() + + assert torch.allclose(out1, out2), \ + f"Microbatch outputs differ; max_diff={(out1-out2).abs().max():.6f}" + dist.destroy_process_group() + + +class TestGTPMicrobatches: + def test_consistent_across_microbatches(self): + _requires_multi_gpu(4) + _run_distributed(_worker_microbatches, 4) + + +# --------------------------------------------------------------------------- +# 11. NVFP4 + GTP: Linear forward/backward, quantized shard setup +# --------------------------------------------------------------------------- + +def _worker_nvfp4_linear(rank, world_size, port): + """Verify that GTP Linear correctly quantizes, all-gathers, and computes with NVFP4.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # batch=32: NVFP4 wgrad GEMM (K=batch) requires K divisible by 32 + batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + # Forward under NVFP4 recipe – triggers setup() and NVFP4 quantization + recipe = NVFP4BlockScaling() + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out = layer(inp, is_first_microbatch=True) + + # After the first forward pass setup() must have created a quantized shard + w = layer.weight + assert w.quantized is not None, "NVFP4 quantized shard must be set after setup()" + assert isinstance(w.quantized, QuantizedTensor), \ + f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" + + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "NVFP4 GTP output has non-finite values" + + # Second microbatch reuses cached quantized weight (skip_weight_cast path) + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out2 = layer(inp.detach(), is_first_microbatch=False) + assert torch.isfinite(out2).all(), "NVFP4 GTP second-microbatch output has non-finite values" + + dist.destroy_process_group() + + +def _worker_nvfp4_linear_unaligned(rank, world_size, port): + """Verify NVFP4 GTP when out_features is not aligned to 16*world_size (padding path). + + out_f is chosen to be divisible by 8 (satisfies NVFP4 GEMM alignment) but not by + 16*world_size (so padding is needed). The last GTP rank receives a shard that is + zero-padded to reach the shard_size boundary. After all-gather, _strip_padding + removes the padded rows from the gathered weight before the GEMM, so the output + has the original out_f columns. + """ + _dist_init(rank, world_size, port) + torch.manual_seed(0) + alignment = 16 * world_size # 64 for world_size=4 + # Choose out_f divisible by 8 (NVFP4 GEMM constraint) but not by 64 (GTP alignment). + # With out_f=56: pad_length=8, shard_size=16, last rank gets 8 rows padded to 16. + out_f = alignment - 8 # 56 for world_size=4 + in_f = 64 + batch = 32 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + with fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()): + out = layer(inp, is_first_microbatch=True) + + # After _strip_padding removes the padded rows, output has out_f (not padded) cols. + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "NVFP4 GTP (unaligned) output has non-finite values" + dist.destroy_process_group() + + +class TestNVFP4LinearGTP: + def test_forward_backward(self): + _requires_nvfp4() + _requires_multi_gpu(4) + _run_distributed(_worker_nvfp4_linear, 4) + + def test_forward_unaligned_padding(self): + _requires_nvfp4() + _requires_multi_gpu(4) + _run_distributed(_worker_nvfp4_linear_unaligned, 4) + + +# --------------------------------------------------------------------------- +# 12. NVFP4 + GTP: GroupedLinear forward/backward (coalesced batched all-gather) +# --------------------------------------------------------------------------- + +def _worker_nvfp4_grouped_linear(rank, world_size, port, num_gemms): + """Verify NVFP4 GTP with GroupedLinear (uses grouped_gather_along_first_dim).""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # NVFP4 split_quantize constraints: in_f % 128 == 0, tokens_per_expert % 64 == 0 + # (Hadamard transform requirement), and K=tokens_per_expert % 32 == 0 for wgrad. + in_f, out_f, total_tokens = 128, 256, num_gemms * 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.GroupedLinear( + num_gemms=num_gemms, in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + assert isinstance(layer.weight0, GTPShardedParam) + + m_splits = [total_tokens // num_gemms] * num_gemms + m_splits[-1] += total_tokens - sum(m_splits) + + inp = torch.randn(total_tokens, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + with fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()): + out = layer(inp, m_splits=m_splits, is_first_microbatch=True) + + assert out.shape == (total_tokens, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "NVFP4 GroupedLinear GTP output has non-finite values" + + # All expert weight shards should be quantized after setup() + for i in range(num_gemms): + name = f"weight{i}" + w = getattr(layer, name) + assert isinstance(w, GTPShardedParam) + assert w.quantized is not None, f"{name}.quantized not set after NVFP4 setup()" + assert isinstance(w.quantized, QuantizedTensor), \ + f"{name}.quantized should be QuantizedTensor, got {type(w.quantized)}" + + for i in range(num_gemms): + w = getattr(layer, f"weight{i}") + w.main_grad = torch.zeros(w.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + dist.destroy_process_group() + + +class TestNVFP4GroupedLinearGTP: + @pytest.mark.parametrize("num_gemms", [2, 4]) + def test_forward_backward(self, num_gemms): + _requires_nvfp4() + _requires_multi_gpu(4) + _run_distributed(_worker_nvfp4_grouped_linear, 4, num_gemms) + + +# --------------------------------------------------------------------------- +# 13. MXFP8 + GTP: Linear forward/backward, quantized shard setup +# --------------------------------------------------------------------------- + +def _worker_mxfp8_linear(rank, world_size, port): + """Verify that GTP Linear correctly quantizes, all-gathers, and computes with MXFP8.""" + from transformer_engine.common.recipe import MXFP8BlockScaling + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # batch=32: MXFP8 wgrad GEMM (K=batch) requires K divisible by MXFP8_BLOCK_SCALING_SIZE=32 + batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + # Forward under MXFP8 recipe – triggers setup() and MXFP8 quantization + recipe = MXFP8BlockScaling() + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out = layer(inp, is_first_microbatch=True) + + # After the first forward pass setup() must have created a quantized shard + w = layer.weight + assert w.quantized is not None, "MXFP8 quantized shard must be set after setup()" + assert isinstance(w.quantized, QuantizedTensor), \ + f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" + + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "MXFP8 GTP output has non-finite values" + + # Backward should complete without error + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None + assert inp.grad.shape == inp.shape + + # Second microbatch reuses cached quantized weight (skip_weight_cast path) + with fp8_autocast(enabled=True, fp8_recipe=recipe): + out2 = layer(inp.detach(), is_first_microbatch=False) + assert torch.isfinite(out2).all(), "MXFP8 GTP second-microbatch output has non-finite values" + + dist.destroy_process_group() + + +def _worker_mxfp8_linear_unaligned(rank, world_size, port): + """Verify MXFP8 GTP when out_features is not aligned to 16*world_size (padding path). + + MXFP8 requires tensor dims divisible by 32, so shard_size (= M_padded / world_size) + must be a multiple of 32. With world_size=4 this requires M_padded % 128 == 0. + out_f=120 gives M_padded=128, shard_size=32 (32 % 32 == 0). The last rank has + 24 real rows zero-padded to 32. After all-gather, _strip_padding removes the padded + rows before the GEMM, so the output has the original out_f columns. + """ + from transformer_engine.common.recipe import MXFP8BlockScaling + _dist_init(rank, world_size, port) + torch.manual_seed(0) + # out_f=120: M_padded=128, shard_size=32, last rank has 24 rows padded to 32. + # 120 is divisible by 8 (GEMM constraint), not by 64 (GTP alignment → padding needed). + out_f = 120 + in_f = 64 + batch = 32 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + with fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()): + out = layer(inp, is_first_microbatch=True) + + # After _strip_padding removes the padded rows, output has out_f (not padded) cols. + assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" + assert torch.isfinite(out).all(), "MXFP8 GTP (unaligned) output has non-finite values" + dist.destroy_process_group() + + +def _requires_mxfp8(): + available, reason = is_mxfp8_available(return_reason=True) + if not available: + pytest.skip(f"MXFP8 not available: {reason}") + + +class TestMXFP8LinearGTP: + def test_forward_backward(self): + _requires_mxfp8() + _requires_multi_gpu(4) + _run_distributed(_worker_mxfp8_linear, 4) + + def test_forward_unaligned_padding(self): + _requires_mxfp8() + _requires_multi_gpu(4) + _run_distributed(_worker_mxfp8_linear_unaligned, 4) + + +# --------------------------------------------------------------------------- +# 14. GTPConfig / update_config +# --------------------------------------------------------------------------- + +class TestGTPConfig: + + def test_update_pad_for_alignment(self): + original = gtp_module.GTP_CONFIG.pad_for_alignment + try: + gtp_module.update_config(pad_for_alignment=8) + assert gtp_module.GTP_CONFIG.pad_for_alignment == 8 + finally: + gtp_module.update_config(pad_for_alignment=original) + + def test_update_weight_prefetch(self): + original = gtp_module.GTP_CONFIG.weight_prefetch + try: + gtp_module.update_config(weight_prefetch=False) + assert gtp_module.GTP_CONFIG.weight_prefetch is False + finally: + gtp_module.update_config(weight_prefetch=original) + + def test_invalid_key_raises(self): + with pytest.raises(ValueError, match="Unknown GTP config option"): + gtp_module.update_config(nonexistent_key=123) + + +# --------------------------------------------------------------------------- +# 15. GTPShardedParam properties – shape computations and padding +# --------------------------------------------------------------------------- + +class TestGTPShardedParamProperties: + + class _FakeGroup: + def __init__(self, size=4, rank=0): + self._size = size + self._rank = rank + def size(self): return self._size + def rank(self): return self._rank + + def _make_param(self, shape, pad_length=0, group_size=4, group_rank=0): + p = GTPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup(size=group_size, rank=group_rank) + p.pad_length = pad_length + p.expert_idx = None + return p + + # --- _unsharded_shape_padded --- + + def test_unsharded_shape_padded_no_padding(self): + # shape=(8, 4), group_size=4 → 8*4=32 rows, no padding + p = self._make_param((8, 4), pad_length=0, group_size=4, group_rank=2) + assert p._unsharded_shape_padded == (32, 4) + + def test_unsharded_shape_padded_last_rank_with_padding(self): + # Local shard includes its slice of padding rows: 16 rows per rank, + # pad_length=1 marks 1 of those (on the last rank) as pad → padded + # unsharded shape = 16 * 4 = 64. pad_length is global metadata, the + # same value lives on every rank's shard. + p = self._make_param((16, 32), pad_length=1, group_size=4, group_rank=3) + assert p._unsharded_shape_padded == (64, 32) + + def test_unsharded_shape_padded_non_last_rank_with_padding(self): + # Non-last rank: pad_length is the same global value, same formula. + p = self._make_param((16, 32), pad_length=1, group_size=4, group_rank=0) + assert p._unsharded_shape_padded == (64, 32) + + # --- _unsharded_shape --- + + def test_unsharded_shape_no_padding(self): + p = self._make_param((8, 4), pad_length=0, group_size=4, group_rank=0) + assert p._unsharded_shape == (32, 4) + + def test_unsharded_shape_strips_padding(self): + # Local 16 rows × 4 ranks = 64 padded; pad_length=1 → unsharded = 63. + p = self._make_param((16, 32), pad_length=1, group_size=4, group_rank=3) + assert p._unsharded_shape == (63, 32) + + # --- get_padded_shard --- + + def test_get_padded_shard_identity_when_no_padding(self): + p = self._make_param((6, 4), pad_length=0) + result = p.get_padded_shard() + assert result is p # identity – no copy needed + + def test_get_padded_shard_identity_non_last_rank(self): + # pad_length > 0 but not the padded last rank → no padding added + p = self._make_param((16, 4), pad_length=1, group_size=4, group_rank=0) + result = p.get_padded_shard() + assert result is p + + def test_get_padded_shard_identity_last_rank(self): + # Under current semantics the local shard already contains its share + # of padding (slicer F.pads with zeros before slicing), so + # get_padded_shard() is the identity on the last rank too. + p = self._make_param((8, 4), pad_length=2, group_size=4, group_rank=3) + assert p.get_padded_shard() is p + + # --- _strip_padding --- + + def test_strip_padding_identity_no_padding(self): + p = self._make_param((8, 4), pad_length=0) + t = torch.randn(32, 4) + assert p._strip_padding(t) is t + + def test_strip_padding_plain_tensor(self): + # Gathered weight [32, 4] with pad_length=1 → strip 1 row → [31, 4] + p = self._make_param((7, 4), pad_length=1, group_size=4, group_rank=0) + t = torch.randn(32, 4) + result = p._strip_padding(t) + assert result.shape == (31, 4) + assert torch.equal(result, t[:-1]) + + def test_strip_padding_multi_row(self): + # pad_length=4 strips 4 rows + p = self._make_param((12, 8), pad_length=4, group_size=4, group_rank=0) + t = torch.ones(64, 8) + result = p._strip_padding(t) + assert result.shape == (60, 8) + + +# --------------------------------------------------------------------------- +# 16. _get_cache_key – expert vs non-expert, fwd vs bwd +# --------------------------------------------------------------------------- + +class TestGTPCacheKey: + + class _FakeGroup: + def size(self): return 4 + def rank(self): return 0 + + def _param(self, shape=(16, 32), expert_idx=None): + p = GTPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup() + p.expert_idx = expert_idx + p.pad_length = 0 + return p + + def test_non_expert_key_same_for_fwd_bwd(self): + """Non-routed params produce the same cache key for fwd and bwd.""" + p = self._param(expert_idx=None) + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == \ + p._get_cache_key(torch.bfloat16, fwd=False, reduce_scatter=False) + + def test_expert_key_differs_fwd_bwd(self): + """For quantized (non-torch.dtype) recipes, expert fwd vs bwd keys differ.""" + p = self._param(expert_idx=0) + # _get_cache_key differentiates fwd/bwd only for non-torch.dtype objects + # (e.g. quantized recipe dtype descriptors). Use a mock to trigger that path. + mock_dtype = "fp8" + assert p._get_cache_key(mock_dtype, fwd=True, reduce_scatter=False) != \ + p._get_cache_key(mock_dtype, fwd=False, reduce_scatter=False) + + def test_different_expert_idx_different_keys(self): + """Two experts with same shape but different indices get distinct keys.""" + p0 = self._param(expert_idx=0) + p1 = self._param(expert_idx=1) + assert p0._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ + p1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) + + def test_same_expert_idx_same_key(self): + """Same-shaped experts with the same idx share a cache key (cross-layer buffer reuse).""" + p_l0 = self._param(expert_idx=0) + p_l1 = self._param(expert_idx=0) + assert p_l0._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == \ + p_l1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) + + def test_different_dtypes_different_keys(self): + p = self._param() + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ + p._get_cache_key(torch.float32, fwd=True, reduce_scatter=False) + + def test_rs_key_differs_from_ag_key(self): + """reduce_scatter=True key must differ from reduce_scatter=False key.""" + p = self._param() + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ + p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=True) + + +# --------------------------------------------------------------------------- +# 17. GTPWeightCache.take() deferred vs get() immediate pool return +# --------------------------------------------------------------------------- + +class TestGTPCacheRelease: + """Tests for GTPWeightCache reserve/get/release semantics.""" + + class _FakeGroup: + def size(self): return 2 + def rank(self): return 0 + + def _param(self, shape=(8, 4)): + p = GTPShardedParam(torch.zeros(*shape)) + p.group = self._FakeGroup() + p.expert_idx = None + p.pad_length = 0 + p._quantizer = None + return p + + def test_release_returns_buffer_to_pool(self): + """release() puts the buffer back so the next reserve+get reuses it.""" + cache = GTPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + cache.release(t1) + # New ticket should pop buf1 from pool + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf2 is buf1, "Buffer should be reused after release()" + cache.release(t2) + + def test_without_release_pool_stays_empty(self): + """Without release(), subsequent reserves allocate fresh buffers.""" + cache = GTPWeightCache() + p = self._param() + t1 = cache.reserve(p, torch.bfloat16, fwd=True) + buf1 = cache.get(t1) + # Do NOT release t1 — pool stays empty + t2 = cache.reserve(p, torch.bfloat16, fwd=True) + buf2 = cache.get(t2) + assert buf2 is not buf1, "Without release, a fresh buffer must be allocated" + + def test_get_same_ticket_returns_same_buf(self): + """get() is idempotent — calling it twice returns the same buffer.""" + cache = GTPWeightCache() + p = self._param() + t = cache.reserve(p, torch.bfloat16, fwd=True) + buf_a = cache.get(t) + buf_b = cache.get(t) + assert buf_a is buf_b + cache.release(t) + + def test_release_invalid_ticket_raises(self): + cache = GTPWeightCache() + with pytest.raises(KeyError): + cache.release(9999) + + +# --------------------------------------------------------------------------- +# 18. tag_gtp_params_with_names – _debug_name population +# --------------------------------------------------------------------------- + +class TestTagGTPParamsWithNames: + + def test_debug_name_populated_for_gtp_param(self): + """GTPShardedParam._debug_name is set to the dotted parameter path.""" + class _FakeGroup: + def size(self): return 1 + def rank(self): return 0 + + model = nn.Linear(4, 8, bias=False) + w = GTPShardedParam(torch.randn(8, 4)) + w.group = _FakeGroup() + model._parameters['weight'] = w + + gtp_module.tag_gtp_params_with_names(model) + assert w._debug_name == 'weight', \ + f"Expected 'weight', got '{w._debug_name}'" + + def test_nested_module_debug_name(self): + """Nested module produces a dotted debug name.""" + class _FakeGroup: + def size(self): return 1 + def rank(self): return 0 + + outer = nn.Sequential(nn.Linear(4, 8, bias=False)) + w = GTPShardedParam(torch.randn(8, 4)) + w.group = _FakeGroup() + outer._modules['0']._parameters['weight'] = w + + gtp_module.tag_gtp_params_with_names(outer) + assert w._debug_name == '0.weight', \ + f"Expected '0.weight', got '{w._debug_name}'" + + def test_non_gtp_params_are_skipped(self): + """Plain nn.Parameter instances are silently ignored.""" + model = nn.Linear(4, 8) + gtp_module.tag_gtp_params_with_names(model) # must not raise + + +# --------------------------------------------------------------------------- +# 19. wrap_module_params_gtp is a no-op when gtp_group.size() == 1 +# --------------------------------------------------------------------------- + +class TestGTPGroupSizeOne: + + class _SingletonGroup: + def size(self): return 1 + def rank(self): return 0 + + def test_no_sharding_when_gtp_size_one(self): + """wrap_module_params_gtp must be a no-op for a singleton GTP group.""" + mod = nn.Linear(32, 64, bias=False) + original_weight = mod.weight + wrap_module_params_gtp(mod, ['weight'], self._SingletonGroup()) + assert mod.weight is original_weight, \ + "gtp_group.size()==1 should leave parameters unchanged" + assert not isinstance(mod.weight, GTPShardedParam) + + +# --------------------------------------------------------------------------- +# 21. weight_prefetch=False: forward still produces correct output +# --------------------------------------------------------------------------- + +def _worker_prefetch_disabled(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + gtp_module.update_config(weight_prefetch=False) + try: + l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, + params_dtype=dtype, device="cuda", gtp_group=gtp_group) + + inp = torch.randn(4, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + + # Single forward pass: builds chain and verifies output is correct + out = l0(inp, is_first_microbatch=True) + l1(inp, is_first_microbatch=True) + + # Chain should still be wired even with prefetch disabled + assert l0.weight.next_w is l1.weight + assert torch.isfinite(out).all(), "Non-finite output with prefetch disabled" + finally: + gtp_module.update_config(weight_prefetch=True) + dist.destroy_process_group() + + +class TestGTPPrefetchDisabled: + def test_forward_works_without_prefetch(self): + _requires_multi_gpu(4) + _run_distributed(_worker_prefetch_disabled, 4) + + +# --------------------------------------------------------------------------- +# 22. fuse_wgrad_accumulation=True: wgrad is accumulated into main_grad +# --------------------------------------------------------------------------- + +def _worker_fuse_wgrad(rank, world_size, port): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 128 # out_f % (16*world_size)==0, no padding + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + fuse_wgrad_accumulation=True, + ) + + # Allocate main_grad on the local shard shape + w = layer.weight + w.main_grad = torch.zeros(w.shape, dtype=dtype, device="cuda") + + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + layer(inp, is_first_microbatch=True).sum().backward() + + # With fused accumulation, wgrad was added into main_grad + assert torch.any(w.main_grad != 0), \ + "main_grad should have been updated by fused wgrad accumulation" + dist.destroy_process_group() + + +class TestFuseWgradAccumulation: + def test_wgrad_accumulated_into_main_grad(self): + _requires_multi_gpu(4) + _run_distributed(_worker_fuse_wgrad, 4) + + +# --------------------------------------------------------------------------- +# 23. _grad_accum_hook is called after reduce-scatter +# --------------------------------------------------------------------------- + +def _worker_main_grad_updated_after_bwd(rank, world_size, port): + """After backward, the wgrad RS path must have accumulated wgrad into main_grad.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + in_f, out_f = 32, 64 + dtype = torch.bfloat16 + gtp_group = dist.new_group(list(range(world_size))) + + layer = te.Linear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + device="cuda", gtp_group=gtp_group, + ) + + # wgrad RS path always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + + inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + layer(inp, is_first_microbatch=True).sum().backward() + + assert torch.any(layer.weight.main_grad != 0), \ + "main_grad should have been updated after the reduce-scatter accumulation" + dist.destroy_process_group() + + +class TestGTPGradAccumHook: + def test_main_grad_updated_after_backward(self): + _requires_multi_gpu(4) + _run_distributed(_worker_main_grad_updated_after_bwd, 4) + + diff --git a/tests/pytorch/distributed/test_tp_gtp.py b/tests/pytorch/distributed/test_tp_gtp.py new file mode 100644 index 0000000000..7310f1c450 --- /dev/null +++ b/tests/pytorch/distributed/test_tp_gtp.py @@ -0,0 +1,427 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for combined Tensor Parallelism + Generalized Tensor Parallelism (TP+GTP). + +Process group layout (world_size = tp_size × gtp_size): + + rank = gtp_rank × tp_size + tp_rank + + TP group: all ranks that share the same gtp_rank (size = tp_size) + GTP group: all ranks that share the same tp_rank (size = gtp_size) + +Test groups +----------- +1. TestTPGTPProcessGroups – verify TP/GTP group sizes and rank assignment +2. TestTPGTPColumnParallelLinear – column-parallel Linear: weight shape + fwd/bwd correctness +3. TestTPGTPRowParallelLinear – row-parallel Linear: weight shape + fwd/bwd smoke test +4. TestTPGTPLayerNormLinear – LayerNormLinear column-parallel smoke test +5. TestTPGTPLayerNormMLP – LayerNormMLP (column FC1 + row FC2) smoke test + +Tests use (tp_size, gtp_size) = (2, 2) → world_size = 4 (runs on 4-GPU machines). + +Multi-GPU tests use torch.multiprocessing.spawn and skip automatically when fewer GPUs are +available. +""" + +import os +import socket + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.module.generalized_tensor_parallelism import GTPShardedParam +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def reset_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +@pytest.fixture(autouse=True) +def reset_gtp_globals(): + """Reset GTP mutable class/module-level state between tests.""" + yield + GTPShardedParam._first_weight_flag = True + GTPShardedParam._pending_rs_weight = None + GTPShardedParam._chain_state = {} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _dist_init(rank: int, world_size: int, port: int) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def _run_distributed(fn, world_size: int, *args) -> None: + """Spawn `world_size` processes, each running fn(rank, world_size, port, *args).""" + port = _free_port() + mp.spawn(fn, args=(world_size, port) + args, nprocs=world_size, join=True) + + +def _requires_multi_gpu(n: int): + if torch.cuda.device_count() < n: + pytest.skip(f"Requires at least {n} CUDA devices") + + +def _build_groups(rank: int, world_size: int, tp_size: int, gtp_size: int): + """Create TP and GTP process groups for a 2D parallelism grid. + + Layout: rank = gtp_rank × tp_size + tp_rank + TP group: contiguous block [gtp_rank*tp_size, (gtp_rank+1)*tp_size) + GTP group: strided set {tp_rank, tp_rank+tp_size, tp_rank+2*tp_size, ...} + + Every rank must call new_group for ALL groups (PyTorch distributed requirement). + + Returns: + tp_group: this rank's TP process group + gtp_group: this rank's GTP process group + tp_rank: this rank's index within its TP group + gtp_rank: this rank's index within its GTP group + """ + assert tp_size * gtp_size == world_size + tp_rank = rank % tp_size + gtp_rank = rank // tp_size + + tp_group = None + for er in range(gtp_size): + ranks = list(range(er * tp_size, (er + 1) * tp_size)) + grp = dist.new_group(ranks) + if er == gtp_rank: + tp_group = grp + + gtp_group = None + for tr in range(tp_size): + ranks = list(range(tr, world_size, tp_size)) + grp = dist.new_group(ranks) + if tr == tp_rank: + gtp_group = grp + + return tp_group, gtp_group, tp_rank, gtp_rank + + +# --------------------------------------------------------------------------- +# 1. TestTPGTPProcessGroups – group sizes and rank membership +# --------------------------------------------------------------------------- + +def _worker_groups(rank, world_size, port, tp_size, gtp_size): + _dist_init(rank, world_size, port) + tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size) + + assert tp_group.size() == tp_size, \ + f"rank {rank}: TP group size {tp_group.size()} != {tp_size}" + assert gtp_group.size() == gtp_size, \ + f"rank {rank}: GTP group size {gtp_group.size()} != {gtp_size}" + assert dist.get_rank(tp_group) == tp_rank, \ + f"rank {rank}: TP rank {dist.get_rank(tp_group)} != expected {tp_rank}" + assert dist.get_rank(gtp_group) == gtp_rank, \ + f"rank {rank}: GTP rank {dist.get_rank(gtp_group)} != expected {gtp_rank}" + + dist.destroy_process_group() + + +class TestTPGTPProcessGroups: + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_group_sizes_and_ranks(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_groups, world_size, tp_size, gtp_size) + + +# --------------------------------------------------------------------------- +# 2. TestTPGTPColumnParallelLinear +# --------------------------------------------------------------------------- + +def _worker_column_shape(rank, world_size, port, tp_size, gtp_size): + """Column-parallel: weight shape must be [out_f/(tp_size*gtp_size), in_f].""" + _dist_init(rank, world_size, port) + tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size) + + in_f = 64 + out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows + + layer = te.Linear( + in_features=in_f, out_features=out_f, + parallel_mode="column", bias=False, params_dtype=torch.bfloat16, + device="cuda", tp_group=tp_group, gtp_group=gtp_group, + ) + + expected_rows = out_f // (tp_size * gtp_size) + assert isinstance(layer.weight, GTPShardedParam), \ + f"rank {rank}: weight should be GTPShardedParam" + assert layer.weight.shape == (expected_rows, in_f), \ + f"rank {rank}: expected ({expected_rows}, {in_f}), got {layer.weight.shape}" + + dist.destroy_process_group() + + +def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size): + """Column-parallel output must equal inp @ (GTP-gathered TP-local weight)^T.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size) + + batch, in_f = 16, 64 + out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows + dtype = torch.bfloat16 + + layer = te.Linear( + in_features=in_f, out_features=out_f, + parallel_mode="column", bias=False, params_dtype=dtype, + device="cuda", tp_group=tp_group, gtp_group=gtp_group, + ) + + # All-gather GTP shards → TP-local full weight [out_f/tp_size, in_f] + shard = layer.weight.data.clone() + all_gtp_shards = [torch.zeros_like(shard) for _ in range(gtp_size)] + dist.all_gather(all_gtp_shards, shard, group=gtp_group) + tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # strip padding + tp_local_weight = tp_local_weight[:out_f // tp_size] + + # Same full input on all ranks (column-parallel: each rank processes full input) + inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(inp, src=0) + inp_te = inp.clone().requires_grad_(True) + + # TE forward: GTP all-gathers weight internally; no TP comm in column-parallel fwd + out = layer(inp_te, is_first_microbatch=True) + assert out.shape == (batch, out_f // tp_size), \ + f"rank {rank}: output shape {out.shape} != ({batch}, {out_f // tp_size})" + + # Reference: this TP rank's output = inp @ tp_local_weight^T + ref = inp.float() @ tp_local_weight.T + ref = ref.to(dtype) + assert torch.allclose(out.float(), ref.float(), atol=0.1, rtol=0.1), ( + f"rank {rank}: output mismatch, " + f"max_diff={(out.float() - ref.float()).abs().max():.4f}" + ) + + # Backward: dX is all-reduced across TP group internally by TE + grad = torch.randn_like(out) + dist.broadcast(grad, src=0) + # wgrad RS path always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.backward(grad) + assert inp_te.grad is not None and inp_te.grad.shape == inp.shape + assert torch.isfinite(inp_te.grad).all(), f"rank {rank}: non-finite dX" + + dist.destroy_process_group() + + +class TestTPGTPColumnParallelLinear: + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_weight_shape(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_column_shape, world_size, tp_size, gtp_size) + + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_forward_backward_correctness(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_column_correctness, world_size, tp_size, gtp_size) + + +# --------------------------------------------------------------------------- +# 3. TestTPGTPRowParallelLinear +# --------------------------------------------------------------------------- + +def _worker_row_shape(rank, world_size, port, tp_size, gtp_size): + """Row-parallel: weight shape must be [out_f/gtp_size, in_f/tp_size].""" + _dist_init(rank, world_size, port) + tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size) + + in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64 + out_f = gtp_size * 64 # GTP divides by gtp_size → local out_f = 64 + + layer = te.Linear( + in_features=in_f, out_features=out_f, + parallel_mode="row", bias=False, params_dtype=torch.bfloat16, + device="cuda", tp_group=tp_group, gtp_group=gtp_group, + ) + + expected_shape = (out_f // gtp_size, in_f // tp_size) + assert isinstance(layer.weight, GTPShardedParam), \ + f"rank {rank}: weight should be GTPShardedParam" + assert layer.weight.shape == expected_shape, \ + f"rank {rank}: expected {expected_shape}, got {layer.weight.shape}" + + dist.destroy_process_group() + + +def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size): + """Row-parallel: output is all-reduced [batch, out_f]; backward produces finite dX.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + tp_group, gtp_group, tp_rank, _ = _build_groups(rank, world_size, tp_size, gtp_size) + + batch = 16 + in_f = tp_size * 64 # full in_features + out_f = gtp_size * 64 # full out_features + dtype = torch.bfloat16 + + layer = te.Linear( + in_features=in_f, out_features=out_f, + parallel_mode="row", bias=False, params_dtype=dtype, + device="cuda", tp_group=tp_group, gtp_group=gtp_group, + ) + + # Row-parallel: each TP rank takes the corresponding slice of in_f + full_inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(full_inp, src=0) + local_in_f = in_f // tp_size + inp = full_inp[:, tp_rank * local_in_f : (tp_rank + 1) * local_in_f] + inp = inp.clone().requires_grad_(True) + + # TE forward: GTP all-gathers weight, row-parallel all-reduces output across TP + out = layer(inp, is_first_microbatch=True) + assert out.shape == (batch, out_f), \ + f"rank {rank}: output shape {out.shape} != ({batch}, {out_f})" + assert torch.isfinite(out).all(), f"rank {rank}: non-finite output" + + # wgrad RS path always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + assert torch.isfinite(inp.grad).all(), f"rank {rank}: non-finite dX" + + dist.destroy_process_group() + + +def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size): + """Row-parallel all-reduced output must equal inp_full @ full_weight^T.""" + _dist_init(rank, world_size, port) + torch.manual_seed(0) + tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size) + + batch = 16 + in_f = tp_size * 64 + out_f = gtp_size * 64 + dtype = torch.bfloat16 + + layer = te.Linear( + in_features=in_f, out_features=out_f, + parallel_mode="row", bias=False, params_dtype=dtype, + device="cuda", tp_group=tp_group, gtp_group=gtp_group, + ) + + # Reconstruct full weight: all-gather GTP shards → TP-local, then all-gather TP shards + shard = layer.weight.data.clone() + all_gtp_shards = [torch.zeros_like(shard) for _ in range(gtp_size)] + dist.all_gather(all_gtp_shards, shard, group=gtp_group) + tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # [out_f, in_f/tp_size] + + all_tp_weights = [torch.zeros_like(tp_local_weight) for _ in range(tp_size)] + dist.all_gather(all_tp_weights, tp_local_weight, group=tp_group) + full_weight = torch.cat(all_tp_weights, dim=1).float() # [out_f, in_f] + + # Full input (same on all ranks; we slice below to simulate row-parallel) + full_inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") + dist.broadcast(full_inp, src=0) + local_in_f = in_f // tp_size + inp = full_inp[:, tp_rank * local_in_f : (tp_rank + 1) * local_in_f].clone() + inp.requires_grad_(True) + + out = layer(inp, is_first_microbatch=True) + + # Reference: full input @ full weight^T — all ranks should see the same output + ref = full_inp.float() @ full_weight.T + ref = ref.to(dtype) + assert torch.allclose(out.float(), ref.float(), atol=0.1, rtol=0.1), ( + f"rank {rank}: output mismatch, " + f"max_diff={(out.float() - ref.float()).abs().max():.4f}" + ) + + dist.destroy_process_group() + + +class TestTPGTPRowParallelLinear: + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_weight_shape(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_row_shape, world_size, tp_size, gtp_size) + + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_forward_backward(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_row_forward_backward, world_size, tp_size, gtp_size) + + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_forward_correctness(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_row_correctness, world_size, tp_size, gtp_size) + + +# --------------------------------------------------------------------------- +# 4. TestTPGTPLayerNormLinear – column-parallel smoke test +# --------------------------------------------------------------------------- + +def _worker_layernorm_linear(rank, world_size, port, tp_size, gtp_size): + _dist_init(rank, world_size, port) + torch.manual_seed(0) + tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size) + + seq, batch = 4, 2 + in_f = 64 + out_f = tp_size * gtp_size * 32 + dtype = torch.bfloat16 + + layer = te.LayerNormLinear( + in_features=in_f, out_features=out_f, + bias=False, params_dtype=dtype, + parallel_mode="column", + device="cuda", tp_group=tp_group, gtp_group=gtp_group, + ) + assert isinstance(layer.weight, GTPShardedParam), \ + f"rank {rank}: LayerNormLinear.weight should be GTPShardedParam" + expected_rows = out_f // (tp_size * gtp_size) + assert layer.weight.shape == (expected_rows, in_f), \ + f"rank {rank}: unexpected weight shape {layer.weight.shape}" + + inp = torch.randn(seq, batch, in_f, dtype=dtype, device="cuda", requires_grad=True) + dist.broadcast(inp, src=0) + + out = layer(inp, is_first_microbatch=True) + assert out.shape == (seq, batch, out_f // tp_size), \ + f"rank {rank}: output shape {out.shape}" + assert torch.isfinite(out).all(), f"rank {rank}: non-finite output" + + # wgrad RS path always accumulates into main_grad; allocate before backward. + layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") + out.sum().backward() + assert inp.grad is not None and inp.grad.shape == inp.shape + assert torch.isfinite(inp.grad).all(), f"rank {rank}: non-finite dX" + + dist.destroy_process_group() + + +class TestTPGTPLayerNormLinear: + @pytest.mark.parametrize("tp_size,gtp_size", [(2, 2)]) + def test_forward_backward(self, tp_size, gtp_size): + world_size = tp_size * gtp_size + _requires_multi_gpu(world_size) + _run_distributed(_worker_layernorm_linear, world_size, tp_size, gtp_size) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 030023d949..0e651a1a0c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -207,6 +207,7 @@ list(APPEND transformer_engine_cuda_sources recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu + recipe/multi_amax.cu comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index cad27a2992..06a37c1800 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -99,6 +99,26 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig config, cudaStream_t stream); +/*! \brief Compute amax for a list of independent tensors in a single kernel launch. + * + * Unlike nvte_group_amax (which requires a single contiguous input split along dim 0), + * this API accepts arrays of independent input tensors, each with its own allocation. + * Designed for the GTP grouped-experts case where per-expert weights live in separate + * buffers. For each i in [0, num_tensors), computes amax(inputs[i]) and writes it to + * outputs[i]'s amax buffer. outputs[i] must be an FP8 per-tensor scaling or NVFP4 1D + * scaling tensor. All inputs must share the same dtype. If the list exceeds the + * per-launch batch capacity, it is internally chunked. + * + * \param[in] inputs Array of input tensors (unquantized). Size num_tensors. + * \param[in,out] outputs Array of output tensors. Only the amax is updated. + * Size num_tensors. + * \param[in] num_tensors Number of tensors. + * \param[in] config Quantization configuration (for noop_tensor). May be NULL. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + const NVTEQuantizationConfig config, cudaStream_t stream); + /*! \brief Update an FP8 tensor's scale based on its amax. * * This is only supported for FP8 tensors with per-tensor scaling. diff --git a/transformer_engine/common/recipe/multi_amax.cu b/transformer_engine/common/recipe/multi_amax.cu new file mode 100644 index 0000000000..5420dde587 --- /dev/null +++ b/transformer_engine/common/recipe/multi_amax.cu @@ -0,0 +1,274 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/vectorized_pointwise.h" +#include "recipe_common.cuh" + +namespace transformer_engine { +namespace { + +constexpr int multi_amax_kernel_threads = 512; +// Per-launch capacity. kMaxTensorsPerBatch * ~40 bytes per slot keeps the args +// struct within the 4KB kernel parameter limit with comfortable headroom. +constexpr int kMaxTensorsPerBatch = 64; + +struct MultiAmaxArgs { + const void *input_list[kMaxTensorsPerBatch]; + void *output_rowwise_amax_list[kMaxTensorsPerBatch]; + void *output_columnwise_amax_list[kMaxTensorsPerBatch]; + size_t input_numel[kMaxTensorsPerBatch]; + size_t num_aligned_elements[kMaxTensorsPerBatch]; + int num_tensors; +}; + +// Zero out every output amax slot (rowwise + columnwise, deduped) in a single launch. +// Respects the noop_ptr contract shared with the single-tensor amax path. +__launch_bounds__(multi_amax_kernel_threads) __global__ + void MultiZeroAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < args.num_tensors; tid += stride) { + float *rw = static_cast(args.output_rowwise_amax_list[tid]); + float *cw = static_cast(args.output_columnwise_amax_list[tid]); + if (rw != nullptr) { + *rw = 0.0f; + } + if (cw != nullptr && cw != rw) { + *cw = 0.0f; + } + } +} + +// Per-tensor amax with one block-strip per tensor. blockIdx.y selects the +// tensor; blockIdx.x is the work chunk within that tensor. Each block +// vector-loads the tensor, reduces across threads, and atomicMaxFloats the +// result into BOTH output amax slots (rowwise + columnwise, deduped). This +// subsumes the per-expert D2D copy that the single-tensor path does after the +// amax kernel. +template +__launch_bounds__(multi_amax_kernel_threads) __global__ + void MultiAmaxKernel(MultiAmaxArgs args, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + + const int t_idx = blockIdx.y; + if (t_idx >= args.num_tensors) { + return; + } + + const InputType *input = static_cast(args.input_list[t_idx]); + const size_t N = args.input_numel[t_idx]; + if (N == 0) { + return; + } + const size_t M = args.num_aligned_elements[t_idx]; + + VectorizedLoader loader(input, N); + InputType max = InputType{0.f}; + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; + tid += gridDim.x * blockDim.x) { + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const InputType val = static_cast(loader.separate()[i]); + __builtin_assume(max >= InputType{0.f}); + if constexpr (std::is_same_v) { +#if __CUDA_ARCH__ >= 800 + max = __hmax(__habs(val), max); +#else + max = static_cast<__nv_bfloat16>( + fmaxf(fabsf(static_cast(val)), static_cast(max))); +#endif + } else if constexpr (std::is_same_v) { + max = __hmax(__habs(val), max); + } else { + max = fmaxf(fabsf(val), max); + } + } + } + + // Reduce amax over block. + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + float *rw = static_cast(args.output_rowwise_amax_list[t_idx]); + float *cw = static_cast(args.output_columnwise_amax_list[t_idx]); + if (rw != nullptr) { + atomicMaxFloat(rw, static_cast(max)); + } + if (cw != nullptr && cw != rw) { + atomicMaxFloat(cw, static_cast(max)); + } + } +} + +template +void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignment align, + const float *noop_ptr, cudaStream_t stream) { + // Zero all amax outputs in one launch. + { + constexpr int threads = multi_amax_kernel_threads; + const int num_blocks = std::max(1, DIVUP(args.num_tensors, threads)); + MultiZeroAmaxKernel<<>>(args, noop_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + if (max_numel == 0) { + return; + } + + // Grid: y = tensor index, x = work chunks within the largest tensor. Blocks + // that exceed a shorter tensor's aligned element count bail out via the + // bounds check inside the kernel. + constexpr int nvec = 32 / sizeof(InputType); + constexpr size_t threads = multi_amax_kernel_threads; + const size_t max_aligned = (max_numel + nvec - 1) / nvec; + size_t num_blocks_x = DIVUP(max_aligned, threads); + constexpr size_t max_blocks = 65535; + num_blocks_x = std::min(num_blocks_x, max_blocks); + num_blocks_x = std::max(num_blocks_x, 1); + dim3 grid(num_blocks_x, static_cast(args.num_tensors), 1); + + switch (align) { + case Alignment::SAME_ALIGNED: + MultiAmaxKernel + <<>>(args, noop_ptr); + break; + case Alignment::SAME_UNALIGNED: + MultiAmaxKernel + <<>>(args, noop_ptr); + break; + case Alignment::DIFFERENT: + // Heterogeneous alignment across tensors — fall back to nvec=1, aligned=true path + // which is safe for any pointer alignment. + MultiAmaxKernel<1, true, InputType> + <<>>(args, noop_ptr); + break; + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Fill one MultiAmaxArgs batch from a slice of the full input/output list. +// Returns (max_numel in this batch, worst-case alignment across the batch). +template +std::pair build_batch_args(const std::vector &inputs, + const std::vector &outputs, size_t start, + size_t count, MultiAmaxArgs &args) { + constexpr int nvec = 32 / sizeof(InputType); + size_t max_numel = 0; + // SAME_ALIGNED is the most optimistic; degrade to SAME_UNALIGNED if any + // tensor is merely same-layout but unaligned, to DIFFERENT if alignment + // varies across tensors. + Alignment batch_align = Alignment::SAME_ALIGNED; + for (size_t i = 0; i < count; ++i) { + const Tensor &inp = *inputs[start + i]; + Tensor &out = *outputs[start + i]; + const size_t N = inp.data.numel(); + void *rw_ptr = out.amax.dptr; + void *cw_ptr = out.columnwise_amax.dptr; + + args.input_list[i] = inp.data.dptr; + args.output_rowwise_amax_list[i] = rw_ptr; + args.output_columnwise_amax_list[i] = cw_ptr; + args.input_numel[i] = N; + args.num_aligned_elements[i] = get_num_aligned_elements(inp.data.dptr, N, nvec, + sizeof(InputType)); + max_numel = std::max(max_numel, N); + + // Fold this tensor's alignment into the batch decision. CheckAlignment on a + // single pointer yields SAME_ALIGNED or SAME_UNALIGNED; mixing the two across + // tensors means heterogeneous — switch to the DIFFERENT fall-back. + if (N > 0) { + Alignment a = CheckAlignment(N, nvec, static_cast(inp.data.dptr)); + if (batch_align == Alignment::SAME_ALIGNED && a == Alignment::SAME_UNALIGNED) { + batch_align = Alignment::SAME_UNALIGNED; + } else if (batch_align == Alignment::SAME_UNALIGNED && a == Alignment::SAME_ALIGNED) { + batch_align = Alignment::SAME_UNALIGNED; + } else if (a == Alignment::DIFFERENT) { + batch_align = Alignment::DIFFERENT; + } + } + } + args.num_tensors = static_cast(count); + return {max_numel, batch_align}; +} + +void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, size_t num_tensors, + const NVTEQuantizationConfig config_, cudaStream_t stream) { + if (num_tensors == 0) { + return; + } + NVTE_CHECK(inputs_ != nullptr, "nvte_multi_compute_amax: inputs is NULL"); + NVTE_CHECK(outputs_ != nullptr, "nvte_multi_compute_amax: outputs is NULL"); + + // Convert, validate, collect into plain vectors. + std::vector inputs(num_tensors); + std::vector outputs(num_tensors); + DType input_dtype; + for (size_t i = 0; i < num_tensors; ++i) { + inputs[i] = convertNVTETensorCheck(inputs_[i]); + outputs[i] = convertNVTETensorCheck(outputs_[i]); + const auto &inp = *inputs[i]; + auto &out = *outputs[i]; + NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "nvte_multi_compute_amax: input[", i, + "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode)); + NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), + "nvte_multi_compute_amax: input[", i, + "] must be unquantized, got dtype=", to_string(inp.data.dtype)); + if (i == 0) { + input_dtype = inp.data.dtype; + } else { + NVTE_CHECK(inp.data.dtype == input_dtype, + "nvte_multi_compute_amax: all inputs must share dtype; input[0]=", + to_string(input_dtype), ", input[", i, "]=", to_string(inp.data.dtype)); + } + NVTE_CHECK(out.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || + out.scaling_mode == NVTE_NVFP4_1D_SCALING, + "nvte_multi_compute_amax: output[", i, "] must be FP8 per-tensor or NVFP4 1D"); + NVTE_CHECK(out.amax.dptr != nullptr || out.columnwise_amax.dptr != nullptr, + "nvte_multi_compute_amax: output[", i, "] has no amax buffer"); + } + + const float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + const NVTETensor noop = config_cpp->noop_tensor; + noop_ptr = reinterpret_cast( + (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + + // Chunk across kMaxTensorsPerBatch launches (single launch in the common 8-expert case). + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input_dtype, IType, { + for (size_t start = 0; start < num_tensors; start += kMaxTensorsPerBatch) { + const size_t count = std::min(kMaxTensorsPerBatch, num_tensors - start); + MultiAmaxArgs args = {}; + auto [max_numel, batch_align] = build_batch_args(inputs, outputs, start, count, args); + launch_multi_amax_batch(args, max_numel, batch_align, noop_ptr, stream); + } + }); // NOLINT(*) +} + +} // anonymous namespace +} // namespace transformer_engine + +void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, + const NVTEQuantizationConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_compute_amax); + transformer_engine::multi_compute_amax_impl(inputs, outputs, num_tensors, config, stream); +} diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..d95ce96af7 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -369,11 +369,30 @@ class NVFP4Quantizer : public Quantizer { */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + /*! @brief Compute (and D2D fill) local amax only — no cast, no allreduce. + * + * Writes the local amax into out's rowwise and/or columnwise amax + * buffers. Callers are expected to perform a coalesced allreduce + * across the amax reduction group afterwards, then invoke + * quantize_cast_only to finish the cast with the reduced amax. + */ + void compute_amax_only(const TensorWrapper& input, TensorWrapper& out); + + /*! @brief Cast to NVFP4 assuming amax already reduced externally. + * + * Skips both local amax compute and the internal amax allreduce. + * Callers must guarantee out's amax buffers already hold the reduced + * amax (e.g. via compute_amax_only + allreduce_coalesced). + */ + void quantize_cast_only(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt); + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, bool compute_amax); + const std::optional& noop_flag, bool compute_amax, + bool skip_amax_reduction = false); void quantize_with_rht_unfused_helper(const TensorWrapper& input, TensorWrapper& out, TensorWrapper& rht_output_t_cpp, QuantizationConfigWrapper& quant_config, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..6a29c3adb3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -329,6 +329,21 @@ py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, std::optional noop_flag); +// NVFP4-only split-phase quantize: compute amax, coalesce allreduce externally, then cast. +py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output); +py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output, + std::optional noop_flag); + +// NVFP4-only multi-tensor amax: fuses N per-expert (zero_amax + amax + D2D replicate) +// chains into a single pair of kernel launches (one multi-zero + one multi-amax) that +// writes amax into every output's rowwise AND columnwise buffers. Outputs must be +// pre-allocated; amax is written in place, no return. +void compute_multi_amax_nvfp4(const std::vector &tensor_list, + std::vector quantizer_list, + const std::vector &output_list); + py::object dequantize(const py::handle &input, DType otype); py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..10a507b194 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -65,6 +65,148 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +/*! @brief NVFP4-only: compute local amax into `output`'s amax buffers, no cast, no allreduce. + * + * Pair with an external coalesced allreduce of the returned amax tensors, + * then call `quantize_cast_only_nvfp4` to finish the cast. + */ +py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output) { + NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer.ptr()), + "compute_amax_nvfp4 requires an NVFP4Quantizer"); + auto quantizer_cpp = convert_quantizer(quantizer); + auto *nvfp4_quantizer = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer != nullptr, "Failed to cast quantizer to NVFP4Quantizer"); + + auto input_contiguous = tensor.contiguous(); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + TensorWrapper output_cpp; + py::object output_py; + if (output.is_none()) { + const auto shape = get_tensor_shape(input_cpp); + const auto fake_dtype = input_cpp.dtype(); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + } else { + std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); + } + + nvfp4_quantizer->compute_amax_only(input_cpp, output_cpp); + return output_py; +} + +/*! @brief NVFP4-only: cast to FP4 using pre-reduced amax in `output`'s amax buffers. + * + * Skips both local amax compute and the internal allreduce. Caller must have + * already populated `output`'s amax via compute_amax_nvfp4 + coalesced allreduce. + */ +py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer, + const py::object &output, + std::optional noop_flag) { + NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer.ptr()), + "quantize_cast_only_nvfp4 requires an NVFP4Quantizer"); + auto quantizer_cpp = convert_quantizer(quantizer); + auto *nvfp4_quantizer = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer != nullptr, "Failed to cast quantizer to NVFP4Quantizer"); + + auto input_contiguous = tensor.contiguous(); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + TensorWrapper output_cpp; + py::object output_py; + if (output.is_none()) { + const auto shape = get_tensor_shape(input_cpp); + const auto fake_dtype = input_cpp.dtype(); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + } else { + std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); + } + + std::optional noop_flag_cpp; + if (noop_flag.has_value()) { + noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); + } + + nvfp4_quantizer->quantize_cast_only(input_cpp, output_cpp, noop_flag_cpp); + return output_py; +} + +/*! @brief NVFP4-only: compute amax for N input tensors in a single launch. + * + * Each output's rowwise AND columnwise amax buffers are populated directly by the + * kernel (atomicMaxFloat), fusing the per-expert zero_amax + amax_kernel + D2D + * replicate chain into two multi-tensor launches. Caller pairs this with an + * external coalesced allreduce and then N calls to quantize_cast_only_nvfp4. + * + * Amax is written into the outputs passed in via output_list; no return value is + * needed — caller already holds references to those objects. + */ +void compute_multi_amax_nvfp4(const std::vector &tensor_list, + std::vector quantizer_list, + const std::vector &output_list) { + const size_t num_tensors = tensor_list.size(); + NVTE_CHECK(num_tensors > 0, "compute_multi_amax_nvfp4 requires at least one tensor"); + NVTE_CHECK(quantizer_list.size() == num_tensors, + "compute_multi_amax_nvfp4: quantizer_list size mismatch"); + NVTE_CHECK(output_list.size() == num_tensors, + "compute_multi_amax_nvfp4: output_list size mismatch"); + + // Locals held for the duration of this call (destroyed at function return). + // TensorWrappers only hold NVTETensor handles (opaque indexes into a global pool + // released by ~TensorWrapper); they do NOT reference quantizer_cpp or py::object, + // so we do not need to preserve quantizer unique_ptrs past this scope. + std::vector input_contiguous; + input_contiguous.reserve(num_tensors); + std::vector input_wrappers; + input_wrappers.reserve(num_tensors); + std::vector output_wrappers; + output_wrappers.reserve(num_tensors); + + std::vector inputs_nvte; + std::vector outputs_nvte; + inputs_nvte.reserve(num_tensors); + outputs_nvte.reserve(num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer_list[i].ptr()), + "compute_multi_amax_nvfp4: quantizer[", i, "] is not an NVFP4Quantizer"); + auto quantizer_cpp = convert_quantizer(quantizer_list[i]); + auto *nvfp4_quantizer = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer != nullptr && !nvfp4_quantizer->with_rht, + "compute_multi_amax_nvfp4 requires NVFP4Quantizer with with_rht=false (idx=", i, + ")"); + + input_contiguous.emplace_back(tensor_list[i].contiguous()); + input_wrappers.emplace_back(makeTransformerEngineTensor(input_contiguous.back())); + + TensorWrapper out_cpp; + py::object out_py; + NVTE_CHECK(!output_list[i].is_none(), + "compute_multi_amax_nvfp4: output_list[", i, "] is None; caller must pre-allocate"); + std::tie(out_cpp, out_py) = quantizer_cpp->convert_and_update_tensor(output_list[i]); + + NVTE_CHECK(out_cpp.get_amax().data_ptr != nullptr || + out_cpp.get_columnwise_amax().data_ptr != nullptr, + "compute_multi_amax_nvfp4: output[", i, "] has no amax buffer"); + + output_wrappers.emplace_back(std::move(out_cpp)); + // quantizer_cpp and out_py are released here at end-of-iteration. + + if (input_wrappers.back().numel() == 0) continue; + inputs_nvte.push_back(input_wrappers.back().data()); + outputs_nvte.push_back(output_wrappers.back().data()); + } + + if (inputs_nvte.empty()) return; + + QuantizationConfigWrapper quant_config; + auto stream = at::cuda::getCurrentCUDAStream(); + NVTE_SCOPED_GIL_RELEASE({ + nvte_multi_compute_amax(inputs_nvte.data(), outputs_nvte.data(), inputs_nvte.size(), + quant_config, stream); + }); +} + py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector &shape, at::ScalarType dtype, at::Device device, bool pin_memory) { auto quantizer_cpp = convert_quantizer(quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..c15379cae5 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -137,6 +137,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none(), py::arg("noop") = py::none()); + m.def("compute_amax_nvfp4", transformer_engine::pytorch::compute_amax_nvfp4, + "NVFP4: compute local amax into output's amax buffers; no cast, no allreduce", + py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none()); + m.def("quantize_cast_only_nvfp4", transformer_engine::pytorch::quantize_cast_only_nvfp4, + "NVFP4: cast using pre-reduced amax in output's amax buffers; skips amax compute and allreduce", + py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none(), + py::arg("noop") = py::none()); + m.def("compute_multi_amax_nvfp4", transformer_engine::pytorch::compute_multi_amax_nvfp4, + "NVFP4: fused multi-tensor amax compute (writes both rowwise+columnwise amax per output)", + py::arg("tensor_list"), py::arg("quantizer_list"), py::arg("output_list")); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); m.def("create_empty_quantized_tensor", diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..7e75c74726 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2269,7 +2269,7 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, - bool compute_amax) { + bool compute_amax, bool skip_amax_reduction) { // Nothing to be done if input is empty if (input.numel() == 0) { return; @@ -2399,7 +2399,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } // amax reduction - if (this->with_amax_reduction) { + if (this->with_amax_reduction && !skip_amax_reduction) { std::vector amax_tensors; // push amax tensors inside if they need to be reduced auto make_amax_tensor = [](void* data_ptr) { @@ -2499,6 +2499,54 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out this->quantize_impl(input, out, std::nullopt, false); } +void NVFP4Quantizer::compute_amax_only(const TensorWrapper& input, TensorWrapper& out) { + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + // Only the non-RHT path is supported for the split-phase API today. + // RHT path's amax depends on the RHT-rotated view, which is produced + // alongside the cast; decoupling amax from cast is not meaningful there. + NVTE_CHECK(!this->with_rht, + "NVFP4Quantizer::compute_amax_only does not support with_rht=true"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); + + // Mirror the compute-amax block of quantize_impl exactly. + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Replicate amax into whichever of rowwise/columnwise slots were requested. + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } +} + +void NVFP4Quantizer::quantize_cast_only(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + // Amax is expected to already live in out's amax buffers (e.g. from + // compute_amax_only + an external coalesced allreduce). Skip both local + // amax compute and the internal allreduce. + this->quantize_impl(input, out, noop_flag, /*compute_amax=*/false, + /*skip_amax_reduction=*/true); +} + std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { const auto [flat_first_dim, last_dim] = get_2d_dims(shape); diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index a0d4ac3530..8bcbb8d6c1 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable -from contextlib import contextmanager, AbstractContextManager, ContextDecorator +from contextlib import contextmanager, AbstractContextManager, ContextDecorator, nullcontext from functools import lru_cache from dataclasses import dataclass import math @@ -918,7 +918,7 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False + inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False, output: torch.Tensor = None ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) @@ -936,7 +936,8 @@ def reduce_scatter_along_first_dim( dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) + if output is None: + output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( output, inp.contiguous(), group=tp_group, async_op=async_op ) @@ -1280,12 +1281,16 @@ def _post_process_nvfp4_gather( handle.wait() handle = None - # Fix the interleaved transposed data from gathering along first dim. - out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) - out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + # TODO + # # Fix the interleaved transposed data from gathering along first dim. + # out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) + # out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) + out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) - # Optionally pad the scaling inverse if needed. - out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + # # Optionally pad the scaling inverse if needed. + # out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) @dataclass @@ -1299,17 +1304,20 @@ class _NVFP4AllGatherAsyncHandle: async_handle: torch.distributed.Work _synchronized: bool = False - def wait(self) -> None: - """Wait for the async operation to complete and post-process the tensor.""" - if self._synchronized: - return - self.async_handle.wait() + def post_process_nvfp4_gather(self) -> None: _post_process_nvfp4_gather( self.output, self.columnwise_data_interleaved, self.columnwise_scale_inv_interleaved, self.world_size, ) + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.async_handle.wait() + self.post_process_nvfp4_gather() self._synchronized = True @@ -1320,6 +1328,8 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, + output_tensor = None, + grouped = False, ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" @@ -1383,6 +1393,12 @@ def _all_gather_nvfp4( out = quantizer(out) return out, None + # Construct NVFP4 output tensor + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + # Cast input tensor to NVFP4 with required data if not isinstance(inp, NVFP4TensorStorage): inp = quantizer(inp) @@ -1395,17 +1411,19 @@ def _all_gather_nvfp4( ) inp = quantizer(inp.dequantize(dtype=dtype)) - # Construct NVFP4 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - - # Coalesce NCCL collectives for gathering data and scale inverses. - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as gather_coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather NVFP4 data for row-wise usage + out_columnwise_data = None if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses @@ -1434,7 +1452,9 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - out._amax_rowwise = inp._amax_rowwise + #TODO: jiemingz + # out._amax_rowwise = inp._amax_rowwise + out._amax_rowwise.copy_(inp._amax_rowwise) # Gather the transposed NVFP4 data along first dimension. Fix format later. if quantizer.columnwise_usage: @@ -1483,17 +1503,25 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - out._amax_columnwise = inp._amax_columnwise + out._amax_columnwise.copy_(inp._amax_columnwise) - handle = gather_coalescing_manager if async_op else None + + handle = coalesced_handle if async_op else None # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. - if async_op and quantizer.columnwise_usage: - handle = _NVFP4AllGatherAsyncHandle( - out, out_columnwise_data, out_scale_inv, world_size, handle - ) - elif quantizer.columnwise_usage: - _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + if quantizer.columnwise_usage: + if async_op or grouped: + # Defer post-processing: either the async op hasn't completed yet, or an + # external coalescing manager owns the NCCL ops and hasn't flushed them. + inner_handle = handle if async_op else None + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, inner_handle + ) + else: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + else: + if handle is not None: + handle.output = out return out, handle @@ -1505,6 +1533,8 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" @@ -1570,15 +1600,22 @@ def _all_gather_mxfp8( inp = quantizer(inp.dequantize(dtype=dtype)) # Construct MXFP8 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - # Coalesce NCCL collectives - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather MXFP8 data for row-wise usage if quantizer.rowwise_usage: @@ -1625,7 +1662,7 @@ def _all_gather_mxfp8( group=process_group, ) - handle = coalescing_manager if async_op else None + handle = coalesced_handle if async_op else None return out, handle @@ -1634,6 +1671,8 @@ def gather_along_first_dim( process_group: dist_group_type, async_op: bool = False, quantizer: Optional[Quantizer] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather tensors and concatenate along first dimension. @@ -1724,6 +1763,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # NVFP4 case @@ -1738,6 +1779,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # High-precision communication for quantized tensors @@ -1767,19 +1810,20 @@ def gather_along_first_dim( inp = inp.dequantize() # Communication for plain PyTorch tensors - out = torch.empty( - out_shape, - dtype=inp.dtype, - device=inp.device, - memory_format=torch.contiguous_format, - ) + if output_tensor is None: + output_tensor = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - out, + output_tensor, inp.contiguous(), group=process_group, async_op=async_op, ) - return out, handle + return output_tensor, handle # Global cache to store symmetric memory tensors diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 746177ec78..3d75c5761a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1604,7 +1604,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if defer_init: return - for name, param in self.named_parameters(recurse=False): + # Names of GTP-sharded weights, for GroupedLinear's post-loop finalize. + _gtp_sharded_weight_names = [] + + for idx, (name, param) in enumerate(self.named_parameters(recurse=False)): # Check if parameter is a DTensor (FSDP2) or regular tensor is_dtensor = isinstance(param, DTensor) dtensor_param = param if is_dtensor else None @@ -1626,10 +1629,32 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) + # GTP slice (pre-quantize): shard freshly-init BF16 weight into a + # per-rank GTPShardedParam. Runs before the FP8 quantize block so + # the full-size FP8 allocation + dequantize round-trip are avoided. + # GTP-sharded params skip the quantize block; their FP8 cache lives + # in GTPShardedParam.quantized, built later by init_quantizer. + gtp_sharded = None + if not is_dtensor and getattr(self, "_gtp_group", None) is not None: + from .generalized_tensor_parallelism import ( + gtp_slice_in_reset_parameters, + ) + + gtp_sharded = gtp_slice_in_reset_parameters( + self, name, param, expert_idx=idx + ) + if gtp_sharded is not None: + param = gtp_sharded + _gtp_sharded_weight_names.append(name) + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None - if self.primary_weights_in_fp8 and fp8_meta_index is not None: + if ( + self.primary_weights_in_fp8 + and fp8_meta_index is not None + and gtp_sharded is None + ): # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: @@ -1657,6 +1682,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. + # Skip the wrap for GTPShardedParam: Parameter.__new__'s custom-tensor + # path detaches and returns a NEW instance, dropping GTP attrs. if is_dtensor: # recreate the DTensor from the parameter. dtensor_param = DTensor.from_local( @@ -1667,7 +1694,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: stride=dtensor_param.stride(), ) dtensor_param = torch.nn.Parameter(dtensor_param) - else: + elif gtp_sharded is None: param = torch.nn.Parameter(param) # Keep high-precision values on CPU if needed @@ -1705,6 +1732,17 @@ def clear(self): else: self.module_setattr(name, dtensor_param) + # GroupedLinear-only: attach weight_list to the first expert's shard for + # batched all-gather. No-op unless self._gtp_is_grouped is set. + if _gtp_sharded_weight_names: + from .generalized_tensor_parallelism import ( + gtp_finalize_module_in_reset_parameters, + ) + + gtp_finalize_module_in_reset_parameters( + self, _gtp_sharded_weight_names + ) + @abstractmethod def forward(self): """Needs override.""" diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py new file mode 100644 index 0000000000..ecc6785314 --- /dev/null +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -0,0 +1,1692 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections import defaultdict +from contextlib import nullcontext +from typing import Dict, List, Optional +from enum import Enum +from dataclasses import dataclass, field +import math +import re +import torch + +from ..distributed import ( + gather_along_first_dim, + reduce_scatter_along_first_dim, + _NVFP4AllGatherAsyncHandle +) +from ..quantized_tensor import QuantizedTensor +from ..tensor import NVFP4TensorStorage, MXFP8TensorStorage +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..utils import nvtx_range_pop, nvtx_range_push, round_up_to_nearest_multiple +from ..constants import NVFP4_BLOCK_SCALING_SIZE, MXFP8_BLOCK_SCALING_SIZE +from .base import get_dummy_wgrad + +import transformer_engine_torch as tex + + +class GTPChain(str, Enum): + """Prefetch chain identifier for an GTPShardedParam. + + GRAPHED — fwd/bwd captured by a CUDA graph (MLM _CudaGraphRunner). + UNGRAPHED — fwd/bwd runs eagerly; includes embedding/output_layer and + routed grouped experts always, plus router/shared_experts + when their scope tag is not in cuda_graph_scope. + + Chains never cross-link (prev_w/next_w stay within one chain). CG + disabled → single UNGRAPHED chain; full-iteration graph → single GRAPHED. + """ + GRAPHED = "GTP_graphed" + UNGRAPHED = "GTP_ungraphed" + + +# Module-level cuda_graph_scope, set by the integrator at init via set_cuda_graph_scope(). +# None or empty → CG is disabled; every GTP param classifies as UNGRAPHED. +# Value is a set of scope tags; e.g. {"mamba","attn","moe_router"}. +_CUDA_GRAPH_SCOPE: Optional[set] = None +# Whether shared_experts are run with overlap (cannot be captured). When True, +# shared_experts stay UNGRAPHED regardless of moe_router scope inclusion, matching +# the transformer_layer.py guard that excludes them from the captured submodules. +_MOE_SHARED_EXPERT_OVERLAP: bool = False + + +def set_cuda_graph_scope(scope, moe_shared_expert_overlap: bool = False): + """Record the active cuda_graph_scope for GTP chain classification. + + Called by MLM at init, BEFORE classify_gtp_chains(). ``scope`` may be + None, an empty iterable (CG disabled), or an iterable of scope tags. + """ + global _CUDA_GRAPH_SCOPE, _MOE_SHARED_EXPERT_OVERLAP + _CUDA_GRAPH_SCOPE = set(scope) if scope else None + _MOE_SHARED_EXPERT_OVERLAP = bool(moe_shared_expert_overlap) + + +def _classify_param_chain(param_name: str) -> 'GTPChain': + """Classify an GTPShardedParam by name + active cuda_graph_scope. + + embedding / output_layer are always UNGRAPHED. Other kinds (mamba mixer, + self/cross_attention, shared_experts, routed experts) are GRAPHED iff + their scope tag is present in cuda_graph_scope; otherwise UNGRAPHED. + """ + n = param_name + + # Always ungraphed — embedding and output_layer live outside any CG runner. + if "embedding" in n or "output_layer" in n: + return GTPChain.UNGRAPHED + + scope = _CUDA_GRAPH_SCOPE + if not scope: + # CG disabled: every GTP param goes to the single UNGRAPHED chain. + return GTPChain.UNGRAPHED + + if ".mlp.shared_experts." in n: + if _MOE_SHARED_EXPERT_OVERLAP: + return GTPChain.UNGRAPHED + return GTPChain.GRAPHED if ("moe" in scope or "moe_router" in scope) else GTPChain.UNGRAPHED + + if ".mlp.experts." in n: + return GTPChain.GRAPHED if "moe" in scope else GTPChain.UNGRAPHED + + if ".self_attention." in n or ".cross_attention." in n: + return GTPChain.GRAPHED if "attn" in scope else GTPChain.UNGRAPHED + + if ".mixer." in n: + return GTPChain.GRAPHED if "mamba" in scope else GTPChain.UNGRAPHED + + return GTPChain.UNGRAPHED + + +def classify_gtp_chains(model) -> None: + """Walk model.named_parameters() and set chain_id on every GTPShardedParam. + + Call once at init, AFTER set_cuda_graph_scope() and BEFORE the first fwd + of any graphed param. Raises if an already chain-initialized param would + be reclassified into a different chain (its prev/next links are already + wired into the wrong list). + """ + conflicts = [] + for name, param in model.named_parameters(): + if not isinstance(param, GTPShardedParam): + continue + target = _classify_param_chain(name).value + if param.prefetch_initialized and param.chain_id != target: + conflicts.append((name, param.chain_id, target)) + continue + param.chain_id = target + + # Bwd-prefetch opt-out: embedding.word_embeddings.weight does not need + # an AG in the bwd pass (its wgrad is a scatter-add on sharded rows + # and its input has no dgrad). Skipping its bwd AG saves one collective. + if "embedding" in name: + param._need_weight_prefetch_bwd = False + if conflicts: + raise RuntimeError( + "classify_gtp_chains: the following params were already chain-initialized " + "with a different chain_id than the classifier would assign — this means " + "their chain links are already wired into the wrong list. Move classification " + "earlier in init. Conflicts: " + + ", ".join(f"{n}: {old!r}->{new!r}" for n, old, new in conflicts[:3]) + + ("..." if len(conflicts) > 3 else "") + ) + + +class GTPWeightState(Enum): + NONE = "NONE" # Sharded, no pending operation + ASYNC_WAIT = "ASYNC_WAIT" # Async all-gather in progress + DATA_READY = "DATA_READY" # Async all-gather complete, result in cache + DATA_READY_SYNC = "DATA_READY_SYNC" # Sync all-gather complete, result in cache + + + +# Global GTP buffer cache (persists across clear(); never set to None after creation). +_GTP_CACHE = None +_GTP_PARAMS = [] + +# Global set of GTPShardedParam with in-flight async comms (AG or RS). +_inflight_comm_params: set = set() +_AG_STREAMS: Dict[str, torch.cuda.Stream] = {} +_RS_STREAMS: Dict[str, torch.cuda.Stream] = {} + +# Wgrad input buffer pool, keyed by (shape, dtype). UNGRAPHED-only: GRAPHED +# wgrad bufs need address stability for CG replay and are not pool-recycled. +_wgrad_buf_pool: Dict[tuple, list] = {} + + +def _wgrad_pool_get(shape: tuple, dtype: torch.dtype, device) -> torch.Tensor: + """Get a pool buffer or allocate fresh. Tagged so _wgrad_pool_put accepts + only pool-owned buffers — callers that don't use _wgrad_pool_get (e.g. + Megatron layers.py wgrad GEMM, aten F.embedding bwd) fall through to the + caching allocator on release.""" + key = (shape, dtype) + pool = _wgrad_buf_pool.get(key) + if pool: + buf = pool.pop() + else: + buf = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) + buf._from_gtp_wgrad_pool = True + return buf + + +def _wgrad_pool_put(buf: torch.Tensor): + """Return a pool-owned buffer for reuse (no-op for untagged buffers; see + _wgrad_pool_get).""" + if not getattr(buf, '_from_gtp_wgrad_pool', False): + return + key = (tuple(buf.shape), buf.dtype) + if key not in _wgrad_buf_pool: + _wgrad_buf_pool[key] = [] + _wgrad_buf_pool[key].append(buf) + + +def _stream_key(chain_id: str, group) -> tuple: + """Key for the per-(chain, group) AG/RS stream dicts. + + Two partitioning axes: + - chain_id: captured (GRAPHED) vs eager (UNGRAPHED) ops must not share + a stream (eager ops would contaminate capture/replay state). + - group: independent NCCL communicators (e.g. GTP vs EGTP) get their + own user-level stream to avoid cross-group serialization. + """ + return (chain_id, id(group) if group is not None else 0) + + +def get_ag_stream(chain_id: str = GTPChain.GRAPHED.value, group=None) -> torch.cuda.Stream: + """Return the GTP all-gather stream for (chain_id, group). See _stream_key.""" + key = _stream_key(chain_id, group) + if key not in _AG_STREAMS: + _AG_STREAMS[key] = torch.cuda.Stream() + return _AG_STREAMS[key] + + +def get_rs_stream(chain_id: str = GTPChain.GRAPHED.value, group=None) -> torch.cuda.Stream: + """Return the GTP reduce-scatter stream for (chain_id, group). See _stream_key.""" + key = _stream_key(chain_id, group) + if key not in _RS_STREAMS: + _RS_STREAMS[key] = torch.cuda.Stream() + return _RS_STREAMS[key] + + +def get_all_ag_streams() -> list: + """All AG streams created so far, across chains and groups.""" + return list(_AG_STREAMS.values()) + + +def get_all_rs_streams() -> list: + """All RS streams created so far, across chains and groups.""" + return list(_RS_STREAMS.values()) + + +def get_ag_streams_for_chain(chain_id: str) -> list: + """AG streams for one chain (all groups that chain has touched).""" + return [s for k, s in _AG_STREAMS.items() if k[0] == chain_id] + + +def get_rs_streams_for_chain(chain_id: str) -> list: + """RS streams for one chain (all groups that chain has touched).""" + return [s for k, s in _RS_STREAMS.items() if k[0] == chain_id] + +# Cached once per process: whether the TE build exposes the split-phase APIs. +_COALESCED_AMAX_TE_APIS_AVAILABLE = ( + hasattr(tex, "compute_amax_nvfp4") and hasattr(tex, "quantize_cast_only_nvfp4") +) + +# Tier-2: multi-tensor amax kernel fuses N per-expert (zero_amax + amax + D2D) chains +# into two multi-tensor kernel launches. Independent of Tier-1 coalesced allreduce. +_MULTI_AMAX_TE_API_AVAILABLE = hasattr(tex, "compute_multi_amax_nvfp4") + + +def _coalesced_amax_static_eligible(weights): + """Check whether the coalesced-amax path is applicable (NVFP4 only). + + Caller already gates on GTP_CONFIG.coalesce_amax_allreduce (False for + non-NVFP4). Here we additionally verify TE API availability, batch size, + quantizer type (must have amax reduction), and the RHT flag.""" + if not _COALESCED_AMAX_TE_APIS_AVAILABLE: + return False + if len(weights) <= 1: + return False + has_amax = [getattr(w._quantizer, "with_amax_reduction", False) for w in weights] + if not all(has_amax): + return False + has_rht = any(getattr(w._quantizer, "with_rht", False) for w in weights) + if has_rht: + return False + return True + + +def _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag): + """Replace the per-weight (compute_amax + allreduce + cast) loop with: + compute_amax loop → one coalesced allreduce → cast loop.""" + group = weights[0]._quantizer.amax_reduction_group + + # Materialize padded shards once; on padded last-rank get_padded_shard() + # launches an F.pad kernel, and we'd otherwise pay it twice per expert. + padded_shards = [w.get_padded_shard() for w in weights] + + # Phase 1: per-weight local amax into each w.quantized's amax buffers. + # Keep rowwise/columnwise both populated so the group allreduce sees + # whichever the consumer GEMM will read. + for w in weights: + w._quantizer.set_usage(rowwise=True, columnwise=True) + if _MULTI_AMAX_TE_API_AVAILABLE: + # Tier-2: single multi-tensor launch writes both rowwise and columnwise + # amax directly (no per-expert D2D replicate), fusing N per-expert chains. + # Reuse the _cached_quantizers list already populated by _all_gather_weight + anchor = weights[0] + quantizer_list = anchor._cached_quantizers + if quantizer_list is None: + quantizer_list = [w._quantizer for w in weights] + anchor._cached_quantizers = quantizer_list + tex.compute_multi_amax_nvfp4( + padded_shards, + quantizer_list, + [w.quantized for w in weights], + ) + else: + for w, shard in zip(weights, padded_shards): + tex.compute_amax_nvfp4( + tensor=shard, + quantizer=w._quantizer, + output=w.quantized, + ) + + # Phase 2: one coalesced allreduce across every weight's amax tensors. + amax_tensors = [] + for w in weights: + rw = w.quantized._amax_rowwise + cw = w.quantized._amax_columnwise + if rw is not None: + amax_tensors.append(rw) + if cw is not None and (rw is None or cw.data_ptr() != rw.data_ptr()): + amax_tensors.append(cw) + torch.distributed.all_reduce_coalesced( + amax_tensors, + op=torch.distributed.ReduceOp.MAX, + group=group, + ) + + # Phase 3: per-weight cast using the pre-reduced amax; skips the internal + # allreduce inside the quantizer. + for w, shard in zip(weights, padded_shards): + tex.quantize_cast_only_nvfp4( + tensor=shard, + quantizer=w._quantizer, + output=w.quantized, + noop=cast_noop_flag, + ) + w.did_cast_to_low_precision = True + + +@dataclass +class GTPConfig: + """Global configuration for Generalized Tensor Parallelism.""" + pad_for_alignment: int = 16 + check_param_states: bool = False + weight_prefetch: bool = True + # When True (default), wgrad reduce-scatter for non-chain-head GTP + # params uses async_op=True; finalize (handle.wait + main_grad.add_) + # runs in the cascade walk of a later bwd call, allowing RS-compute + # overlap. When False, every wgrad RS is synchronous and finalizes + # inline, at the cost of that overlap. + async_reduction: bool = True + # When True, _Linear.backward and _LayerNormLinear.backward run wgrad + # GEMM before dgrad GEMM. The GTP wgrad reduce-scatter is issued between + # them so its NCCL kernel overlaps with the dgrad GEMM, and the prev_w + # AG prefetch (issued by all_gather_and_prefetch_bwd at the top of bwd) + # overlaps with wgrad GEMM. When False (default), use the original + # dgrad-first order. Only affects _Linear and _LayerNormLinear; MLP and + # GroupedLinear keep the original schedule. + wgrad_before_dgrad: bool = False + # GTP companion to Megatron --fp8-param-gather: optimizer casts FP32 master + # directly into GTPShardedParam.quantized; forward's _quantize_if_needed + # short-circuits to the cached FP8. Moves BF16->FP8 off the fwd critical path. + fp8_param_gather: bool = False + # When True and the weight list in _all_gather_weight contains >1 NVFP4 + # shards that share an amax reduction group, coalesce their per-expert + # amax allreduces into a single NCCL call. Requires TE with + # tex.compute_amax_nvfp4 / tex.quantize_cast_only_nvfp4; the eligibility + # guard in _coalesced_amax_static_eligible falls back to the per-weight + # path when either binding is missing. + coalesce_amax_allreduce: bool = True + +GTP_CONFIG = GTPConfig() + + +def update_config(**kwargs): + """Update the global GTP configuration.""" + for key, value in kwargs.items(): + if not hasattr(GTP_CONFIG, key): + raise ValueError(f"Unknown GTP config option: {key}") + setattr(GTP_CONFIG, key, value) + + +def tag_gtp_params_with_names(model): + """Populate _debug_name on every GTPShardedParam with its full dotted parameter name. + + Call once after model construction so the linking log prints human-readable names + instead of raw tensor ids. + """ + for name, param in model.named_parameters(): + if isinstance(param, GTPShardedParam): + param._debug_name = name + + +def _gtp_slice_one_param(param, gtp_group, *, name=""): + """Pad + slice a full-size BF16 weight to this rank's GTP shard. + + Caller attaches GTP attrs (see _gtp_attach_attrs). When called from the + legacy post-init path under fp8_model_init, tensor may be a + QuantizedTensor — F.pad dequantizes it before slicing. + """ + gtp_size = gtp_group.size() + gtp_rank = gtp_group.rank() + tensor = param.data + + if GTP_CONFIG.pad_for_alignment > 0: + # Pad before slicing so shards stay alignment-divisible and padding + # ends up contiguous at the tail of the gathered result. + alignment = GTP_CONFIG.pad_for_alignment * gtp_size + dim0 = tensor.shape[0] + pad_length = (alignment - dim0 % alignment) % alignment + if pad_length > 0: + tensor = torch.nn.functional.pad(tensor, (0, 0, 0, pad_length)) + else: + # No-pad mode: dim-0 must divide gtp_size or AG output loses tail rows. + assert tensor.shape[0] % gtp_size == 0, ( + f"_gtp_slice_one_param: {name}.shape[0]={tensor.shape[0]} is not " + f"divisible by gtp_size={gtp_size}. Either enable padding by " + f"setting GTP_CONFIG.pad_for_alignment > 0, or ensure the weight's " + f"dim-0 is a multiple of the GTP group size." + ) + pad_length = 0 + + shard_size = tensor.shape[0] // gtp_size + shard = tensor[gtp_rank * shard_size : (gtp_rank + 1) * shard_size] + gtp_shard = GTPShardedParam(shard.clone()) + gtp_shard.pad_length = pad_length + return gtp_shard + + +def _gtp_attach_attrs(gtp_shard, gtp_group, *, is_grouped=False, expert_idx=0): + """Attach group / ps_size / routed-expert tags and register in _GTP_PARAMS. + + Kept separate from _gtp_slice_one_param so attrs land on the post-quantize + param (when quantize fires between slice and attach). + """ + if is_grouped: + gtp_shard.expert_idx = expert_idx + gtp_shard.is_routed_expert = True + # Default to UNGRAPHED; classify_gtp_chains() reclassifies based on the + # cuda_graph_scope at init time. + gtp_shard.chain_id = GTPChain.UNGRAPHED.value + gtp_shard.group = gtp_group + gtp_shard.ps_size = gtp_group.size() + global _GTP_PARAMS + _GTP_PARAMS.append(gtp_shard) + + +def wrap_module_params_gtp(module, weight_names, gtp_group, is_grouped=None): + """Shard and re-register module params as GTPShardedParam. + + Two call paths: + 1. Megatron-style modules (ColumnParallelLinear, etc.): full post-init slice. + 2. TE modules: per-param body no-ops because the reset_parameters hook + already produced GTPShardedParam instances. + """ + if gtp_group.size() == 1: + return + + for idx, name in enumerate(weight_names): + param = getattr(module, name, None) + if param is None: + continue + + # TE-side hook already sliced this one. + if isinstance(param, GTPShardedParam): + continue + + # delete the original parameter, which will be replaced by an GTP sharded one + delattr(module, name) + gtp_shard = _gtp_slice_one_param(param, gtp_group, name=name) + del param + _gtp_attach_attrs( + gtp_shard, gtp_group, is_grouped=bool(is_grouped), expert_idx=idx + ) + # register the newly sharded param back to the module + module._parameters[name] = gtp_shard + + if is_grouped: + allweights = [getattr(module, name) for name in weight_names] + allweights[0].weight_list = allweights + + +def gtp_slice_in_reset_parameters(module, name, param, expert_idx=0): + """Slice + attach attrs for one param. Called between init_fn(param) and + the optional quantizer(param) in TransformerEngineBaseModule.reset_parameters. + + Only fires for params in module.weight_names (the GEMM weights); + layer-norm gammas, biases, etc. are left full-size. + + Returns the new GTPShardedParam or None (GTP not active for this param). + """ + gtp_group = getattr(module, "_gtp_group", None) + if gtp_group is None or gtp_group.size() == 1: + return None + weight_names = getattr(module, "weight_names", None) + if weight_names is None or name not in weight_names: + return None + is_grouped = bool(getattr(module, "_gtp_is_grouped", False)) + gtp_shard = _gtp_slice_one_param(param, gtp_group, name=name) + _gtp_attach_attrs( + gtp_shard, gtp_group, is_grouped=is_grouped, expert_idx=expert_idx + ) + return gtp_shard + + +def gtp_finalize_module_in_reset_parameters(module, weight_names): + """GroupedLinear-only: attach weight_list to expert 0's shard for batched + all-gather. No-op when module._gtp_is_grouped is False. + """ + if not getattr(module, "_gtp_is_grouped", False): + return + gtp_group = getattr(module, "_gtp_group", None) + if gtp_group is None or gtp_group.size() == 1: + return + allweights = [getattr(module, n) for n in weight_names] + if allweights: + allweights[0].weight_list = allweights + + +class GTPShardHandle: + + def __init__(self, handle, gtp_shards, reduce_scatter=False): + self.handle = handle + self.gtp_shards = gtp_shards + self.reduce_scatter = reduce_scatter + _inflight_comm_params.add(gtp_shards[0]) + + def wait(self): + if self.handle is not None: + self.handle.wait() + self.handle = None # Release NCCL Work and its C++ tensor references promptly + if GTP_CONFIG.check_param_states: + for w in self.gtp_shards: + if self.reduce_scatter: + w._set_rs_state(GTPWeightState.DATA_READY) + else: + w._set_state(GTPWeightState.DATA_READY) + + _inflight_comm_params.discard(self.gtp_shards[0]) + + +class GTPShardedParam(torch.nn.Parameter): + + _pending_rs_weight = None + _first_weight_flag = True + # Per-chain state: each chain_id (GTPChain.GRAPHED / GTPChain.UNGRAPHED) has + # its own linked list. Chains never cross-link: prev_w/next_w only connect + # params with the same chain_id. + _chain_state: Dict[str, dict] = {} + + @classmethod + def _get_chain_state(cls, chain_id: str) -> dict: + if chain_id not in cls._chain_state: + cls._chain_state[chain_id] = { + 'last_weight': None, + 'link_node_count': 0, + 'link_table_buffer': [], + 'link_table_flushed': False, + } + return cls._chain_state[chain_id] + + @classmethod + def _buffer_link_table_row(cls, prev: "GTPShardedParam", curr: "GTPShardedParam", chain: dict) -> None: + """Buffer one row of the prefetch-link table (flushed atomically on the second forward pass).""" + _W = 70 + + def _layer_id(name: str) -> str: + m = re.search(r"\d+", name) + return m.group() if m else "-" + + chain['link_node_count'] += 1 + if chain['link_node_count'] == 1: + chain_id = getattr(curr, 'chain_id', GTPChain.UNGRAPHED.value) + chain['link_table_buffer'].append( + f"\n[{chain_id} chain]" + f"\n{'node_id':>7} | {'layer_id':>8} | {'curr_weight_name':<{_W}} | prev_weight_name" + f"\n{'-'*7}-+-{'-'*8}-+-{'-'*_W}-+-{'-'*_W}" + ) + # Seed weight (first GTP param) as row 0 + chain['link_table_buffer'].append( + f"{'0':>7} | {_layer_id(prev._debug_name):>8} | {prev._debug_name:<{_W}} | -" + ) + chain['link_table_buffer'].append( + f"{chain['link_node_count']:>7} | {_layer_id(curr._debug_name):>8} | " + f"{curr._debug_name:<{_W}} | {prev._debug_name}" + ) + + @staticmethod + def __new__(cls, tensor, *args, **kwargs): + requires_grad = kwargs.get('requires_grad', True) + return super(GTPShardedParam, cls).__new__(cls, tensor, requires_grad=requires_grad) + + def __init__(self, x, *args, **kwargs): + super().__init__() + + # all gather + self.state = GTPWeightState.NONE + self._ag_ticket_fwd = None + self._ag_ticket_bwd = None + self._prefetch_handle = None + self._need_weight_prefetch = True + # Per-direction prefetch opt-outs. Default True. The embedding weight + # never needs an AG during bwd (its wgrad is a scatter-add indexed by + # token ids, and its input is non-differentiable, so no dgrad either). + # classify_gtp_chains() sets this to False for embedding.word_embeddings.weight. + self._need_weight_prefetch_bwd = True + self.ag_event = torch.cuda.Event(external=True) + # DDP backward hook (set by register_grad_accum_hook); invoked from _finalize_wgrad. + self._grad_accum_hook = None + # Quantization + self._quantizer = None + self.did_cast_to_low_precision = False + self.quantized = None + # Prefetching linked list + self.prefetch_initialized = False + self.next_w = None + self.prev_w = None + # Chain identity (GTPChain.GRAPHED / GTPChain.UNGRAPHED). Defaults to + # UNGRAPHED as a safe fallback; classify_gtp_chains(model) walks the + # model at init time (after set_cuda_graph_scope) and reclassifies + # based on param name + active cuda_graph_scope. + self.chain_id = GTPChain.UNGRAPHED.value + # Grouped gemm + self.is_routed_expert = False + self.expert_idx = None + self.group = None + self.weight_list = None + # Reduce-scatter state (set during wgrad_reduce_scatter) + self.rs_state = GTPWeightState.NONE + self._wgrad_rs_handle = None + self.rs_event = torch.cuda.Event(external=True) + self._rs_ticket = None + # Padding + self.pad_length = 0 + # Debug + self._debug_name = "" + # Hot-path caches (populated lazily on first use). chain_id/group are + # set after __init__, so we can't resolve streams eagerly here. + self._cached_ag_stream = None + self._cached_rs_stream = None + self._cached_quantizers = None + self._cached_dtypes = None + self._cached_gtp_group = None + + def setup(self, weight_quantizer=None): + """Set quantizer and create quantized shard.""" + + if self._quantizer is None: + def _configure_quantizer(q, group): + q = q.copy() + if hasattr(q, 'with_amax_reduction'): + q.with_amax_reduction = True + q.amax_reduction_group = group + q.internal = False + # MXFP8 scales must stay in compact (unswizzled) layout so that + # per-shard scale_inv can be all-gathered via byte concatenation. + # GEMM-swizzled scales from independent shards don't compose into + # a valid swizzled layout for the full tensor after AG. + q.optimize_for_gemm = not isinstance(q, MXFP8Quantizer) + return q + + weights = self.weight_list if self.is_routed_expert and self.weight_list is not None else [self] + for quantizer, weight in zip(weight_quantizer, weights): + if quantizer is None: + continue + + weight._quantizer = _configure_quantizer(quantizer, weight.group) + weight.quantized = weight._quantizer.quantize(weight.get_padded_shard()) + weight.quantized.is_routed_expert = getattr(weight, 'is_routed_expert', False) + # fp8_param_gather: the init quantize above already produced a + # valid FP8 cache from the BF16 shard; flag did_cast so iter-0's + # forward _quantize_if_needed short-circuits and the redundant + # BF16->FP8 cast on iter 0 is skipped. + if GTP_CONFIG.fp8_param_gather: + weight.did_cast_to_low_precision = True + + @property + def _weights(self): + """Return the list of individual weight shards (self for non-routed, weight_list for routed).""" + weights = self.weight_list if self.is_routed_expert else [self] + # Only meaningful when _set_state is actively tracking transitions. + if GTP_CONFIG.check_param_states: + assert all(w.state == weights[0].state for w in weights) + return list(weights) + + @property + def _unsharded_shape_padded(self): + out_shape = list(self.size()) + out_shape[0] = out_shape[0] * self.group.size() + return tuple(out_shape) + + @property + def _unsharded_shape(self): + out_shape = list(self._unsharded_shape_padded) + out_shape[0] -= self.pad_length + return tuple(out_shape) + + @property + def _sharded_padded_shape(self): + return tuple(self.size()) + + def get_padded_shard(self): + return self + + def _set_state(self, new_state: GTPWeightState): + # Only inspected when check_param_states is on; skip writes otherwise. + if not GTP_CONFIG.check_param_states: + return + self.state = new_state + + def _set_rs_state(self, new_state: GTPWeightState): + if not GTP_CONFIG.check_param_states: + return + self.rs_state = new_state + + def _get_cache_key(self, dtype, fwd: bool, reduce_scatter: bool) -> tuple: + """Build cache key using output shape + dtype. + + Weights with matching gathered shape and dtype share a buffer. + For expert weights gathered in parallel, self.expert_idx distinguishes them so + each gets a distinct buffer, while same-indexed experts across layers share. + """ + + if not isinstance(dtype, torch.dtype): + return (self._unsharded_shape_padded, dtype, fwd, not fwd, self.expert_idx, reduce_scatter) + return (self._unsharded_shape_padded, dtype, self.expert_idx, reduce_scatter) + + def _quantize_if_needed(self, skip_weight_cast=False, cast_noop_flag=None): + """Re-quantize sharded weight into existing buffer. Returns quantized weight or self.""" + if self._quantizer is None: + self.did_cast_to_low_precision = False + return self + + # fp8_param_gather fast-path: optimizer already filled self.quantized; + # reuse it and keep BF16->FP8 off the forward critical path. + if GTP_CONFIG.fp8_param_gather and self.did_cast_to_low_precision: + return self.quantized + + self._quantizer.set_usage(rowwise=True, columnwise=True) + if skip_weight_cast is False or cast_noop_flag is not None: + tex.quantize( + tensor=self.get_padded_shard(), + quantizer=self._quantizer, + output=self.quantized, + noop=cast_noop_flag, + ) + self.did_cast_to_low_precision = True + + return self.quantized + + def _strip_padding(self, tensor): + if self.pad_length == 0: + return tensor + + if isinstance(tensor, QuantizedTensor): + assert isinstance(tensor, (NVFP4TensorStorage, MXFP8TensorStorage)), \ + f"Unsupported quantized tensor type for GTP padding: {type(tensor)}" + + metadata = tensor.get_metadata() + if metadata.get("rowwise_data") is not None: + metadata["rowwise_data"] = metadata["rowwise_data"][:-self.pad_length] + if metadata.get("columnwise_data") is not None: + if isinstance(tensor, NVFP4TensorStorage): + # NVFP4 transposes columnwise and packs 2 values per byte + metadata["columnwise_data"] = metadata["columnwise_data"][ + ..., :-self.pad_length // 2 + ].contiguous() + else: + # MXFP8 columnwise is not transposed, strip first dim + metadata["columnwise_data"] = metadata["columnwise_data"][ + :-self.pad_length + ] + M = self._unsharded_shape[0] + if isinstance(tensor, NVFP4TensorStorage): + # NVFP4 scale_inv shapes (see NVFP4Quantizer.get_scale_shape): + # rowwise_scale_inv: [round_up(M, 128), round_up(ceil(K/16), 4)] + # columnwise_scale_inv: [round_up(K, 128), round_up(ceil(M/16), 4)] + # GTP shards M (dim 0 of the weight), so strip to the unpadded sizes. + if metadata.get("rowwise_scale_inv") is not None: + m_rows = round_up_to_nearest_multiple(M, 128) + metadata["rowwise_scale_inv"] = metadata["rowwise_scale_inv"][:m_rows] + if metadata.get("columnwise_scale_inv") is not None: + m_tiles = round_up_to_nearest_multiple( + math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4 + ) + metadata["columnwise_scale_inv"] = ( + metadata["columnwise_scale_inv"][:, :m_tiles].contiguous() + ) + else: + # MXFP8 scale_inv shapes (see MXFP8Quantizer.get_scale_shape): + # rowwise_scale_inv: [round_up(M, 128), round_up(K//32, 4)] + # columnwise_scale_inv: [round_up(M//32, 4), round_up(K, 128)] + # GTP shards M (dim 0 of the weight), so strip to the unpadded sizes. + if metadata.get("rowwise_scale_inv") is not None: + m_rows = round_up_to_nearest_multiple(M, 128) + metadata["rowwise_scale_inv"] = metadata["rowwise_scale_inv"][:m_rows] + if metadata.get("columnwise_scale_inv") is not None: + m_tiles = round_up_to_nearest_multiple( + M // MXFP8_BLOCK_SCALING_SIZE, 4 + ) + metadata["columnwise_scale_inv"] = ( + metadata["columnwise_scale_inv"][:m_tiles] + ) + + return type(tensor)(**metadata, shape=self._unsharded_shape, dtype=torch.bfloat16) + else: + return tensor[:-self.pad_length] + + def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nvtx_label=None): + """Quantize (if needed) and all-gather weight. Returns (weight_total, handle).""" + if nvtx_label is None: + nvtx_label = ( + self._debug_name + + (".fwd" if fwd else ".bwd") + + (".async" if async_op else ".sync") + ) + nvtx_range_push(f"{nvtx_label}.all_gather_weight") + + weights = self._weights + + # 1. Transition state for async gathers. + if GTP_CONFIG.check_param_states: + new_state = GTPWeightState.ASYNC_WAIT if async_op else GTPWeightState.DATA_READY_SYNC + for w in weights: + w._set_state(new_state) + + # 2. Prepare: quantize, set usage direction. + # Static eligibility (quantizer class, flags, amax group) is fixed + # after model construction — compute once and cache on self so the + # hot path only pays the cheap per-call skip_weight_cast check. + if GTP_CONFIG.coalesce_amax_allreduce: + static_ok = getattr(self, "_coalesced_amax_static", None) + if static_ok is None: + static_ok = _coalesced_amax_static_eligible(weights) + self._coalesced_amax_static = static_ok + # Per-call: match the skip_weight_cast gate in _quantize_if_needed + # (fire when either skip_weight_cast is False or cast_noop_flag + # was provided by the FP8/NVFP4 recipe). + use_coalesced = static_ok and not ( + skip_weight_cast is True and cast_noop_flag is None + ) + else: + use_coalesced = False + + # Quantize step: coalesced batch / fp8_param_gather cache hit (skip) / + # legacy per-weight. set_usage runs uniformly after, gated by did_cast. + fp8_pg_hit = GTP_CONFIG.fp8_param_gather and self.did_cast_to_low_precision + + if use_coalesced: + _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag) + elif not fp8_pg_hit: + for w in weights: + w._quantize_if_needed(skip_weight_cast, cast_noop_flag) + + for w in weights: + if w.did_cast_to_low_precision: + w._quantizer.set_usage(rowwise=fwd, columnwise=not fwd) + + # 3. Build gather inputs. + # quantizers / dtypes / gtp_group are stable after model construction — + # cache on the anchor (self == weights[0]) to avoid rebuilding lists + # every call. w.quantized is NOT cached because it can rebind. + quantizers = self._cached_quantizers + if quantizers is None: + quantizers = [w._quantizer for w in weights] + self._cached_quantizers = quantizers + if weights[0].did_cast_to_low_precision: + gather_weights = [w.quantized for w in weights] + else: + gather_weights = list(w.get_padded_shard() for w in weights) + + # 4. Cache checkout — use pooled buffers for both async and sync gathers + # to avoid allocating fresh memory each iteration. + dtypes = self._cached_dtypes + if dtypes is None: + dtypes = [q.dtype if q is not None else w.dtype for q, w in zip(quantizers, weights)] + self._cached_dtypes = dtypes + out_buffers = [] + cache = get_global_GTP_cache() + for p, dt in zip(weights, dtypes): + if fwd: + if p._ag_ticket_fwd is None: + p._ag_ticket_fwd = cache.reserve(p, dt, fwd=True) + cache.get(p._ag_ticket_fwd) + cache.release(p._ag_ticket_fwd) + out_buffers.append(cache.get(p._ag_ticket_fwd)) + else: + if p._ag_ticket_bwd is None: + p._ag_ticket_bwd = cache.reserve(p, dt, fwd=False) + out_buffers.append(cache.get(p._ag_ticket_bwd)) + + # 5. Communicate. + gtp_group = self._cached_gtp_group + if gtp_group is None: + gtp_group = weights[0].group + self._cached_gtp_group = gtp_group + if GTP_CONFIG.check_param_states and len(gather_weights) > 1: + # Debug invariant: batched AG needs distinct output buffers per expert. + assert len(set(id(b) for b in out_buffers)) == len(out_buffers), \ + "Duplicate output buffers in batched all-gather — experts need distinct cache keys" + + # ASYNC AG: wrap issue on ag_stream — ag_stream's tail then reflects + # the collective's full lifecycle (what external wait_stream(ag_stream) + # drains depend on). The explicit outer→ag_stream sync event preserves + # the upstream quantize writer edge that the bare stream context would + # drop; held on self so PyTorch's event pool can't recycle the handle + # between capture and replay. + # SYNC AG: stay on caller — output ready on return. + if async_op: + outer_stream = torch.cuda.current_stream() + ag_stream = get_ag_stream(self.chain_id, gtp_group) + if getattr(self, '_ag_outer_sync_event', None) is None: + self._ag_outer_sync_event = torch.cuda.Event() + outer_sync_event = self._ag_outer_sync_event + outer_sync_event.record(outer_stream) + ag_stream.wait_event(outer_sync_event) + ag_ctx = torch.cuda.stream(ag_stream) + else: + ag_ctx = nullcontext() + + with ag_ctx: + if len(gather_weights) > 1: + nvtx_range_push(f"{nvtx_label}.batched_gtp_ag") + results, handle = grouped_gather_along_first_dim( + gather_weights, gtp_group, + async_op=async_op, + quantizers=quantizers, + output_tensors=out_buffers, + ) + nvtx_range_pop(f"{nvtx_label}.batched_gtp_ag") + else: + nvtx_range_push(f"{nvtx_label}.gtp_ag") + weight_total, handle = gather_along_first_dim( + gather_weights[0], gtp_group, + quantizer=quantizers[0], + async_op=async_op, + output_tensor=out_buffers[0] if out_buffers is not None else None, + ) + nvtx_range_pop(f"{nvtx_label}.gtp_ag") + results = [weight_total] + + result = results if self.is_routed_expert else results[0] + + # 6. Wrap handle. + if async_op: + handle = GTPShardHandle(handle, weights) + else: + handle = None + + nvtx_range_pop(f"{nvtx_label}.all_gather_weight") + return result, handle + + def _wait_param_gather(self): + # Enter ag_stream context so handle.wait() + ag_event.record() both + # land on ag_stream. That makes ag_event mark ag_stream's tail, which + # is what external drains via wait_stream(ag_stream) actually block on. + ag_stream = self._cached_ag_stream + if ag_stream is None: + ag_stream = get_ag_stream(self.chain_id, self.group) + self._cached_ag_stream = ag_stream + with torch.cuda.stream(ag_stream): + if self._prefetch_handle is not None: + self._prefetch_handle.wait() + self._prefetch_handle = None + self.ag_event.record() + + def _all_gather_weight_on_demand(self, fwd, skip_weight_cast=False, cast_noop_flag=None): + result, _ = self._all_gather_weight( + async_op=False, + skip_weight_cast=skip_weight_cast, + cast_noop_flag=cast_noop_flag, + fwd=fwd, + ) + result = result if self.is_routed_expert else [result] + result = [self._strip_padding(r) for r in result] + result = [r.detach().requires_grad_(w.requires_grad) for r, w in zip(result,self._weights)] + return result if self.is_routed_expert else result[0] + + def _get_prefetched_weight(self, fwd, skip_weight_cast=False, cast_noop_flag=None): + # Stale-read guard: state must reflect an AG issued for this cycle; + # otherwise cache.get() would return the prior iter's AG buffer. + if GTP_CONFIG.check_param_states: + for w in self._weights: + assert w.state in ( + GTPWeightState.ASYNC_WAIT, + GTPWeightState.DATA_READY, + GTPWeightState.DATA_READY_SYNC, + ), ( + f"[GTP] _get_prefetched_weight({'fwd' if fwd else 'bwd'}) on " + f"{self._debug_name} with state={w.state!r} — no AG issued; " + f"cache.get() would return stale data. Check the chain's " + f"_need_weight_prefetch flag and issuer's prefetch logic." + ) + _was_drained = getattr(self, '_already_ag_drained', False) + if _was_drained: + # Producer already drained via wait_async_comms; skip the captured + # cross-graph wait (CUDA no-op anyway). Correctness is provided by + # the eager main_stream sync chain in the surrounding training loop. + self._already_ag_drained = False + else: + # Intra-graph or eager consume: drain inline. + self._wait_param_gather() + self.ag_event.wait() + + # Retrieve prefetched results from cache + result = [] + cache = get_global_GTP_cache() + for w in self._weights: + ticket = w._ag_ticket_fwd if fwd else w._ag_ticket_bwd + result.append(cache.get(ticket)) + + result = [self._strip_padding(r) for r in result] + + result = [r.detach().requires_grad_(w.requires_grad) for r, w in zip(result, self._weights)] + return result if self.is_routed_expert else result[0] + + def all_gather_and_prefetch_bwd(self, nvtx_label=None): + """ + Backward variant: get current weight (from cache if prefetched, else + sync gather) and async-prefetch prev_w. + + Safe thanks to the coat-check cache: get() returns the current buffer + to the pool, and the prefetch's checkout() will allocate a separate + buffer if the pool is empty (i.e. the current buffer is still live + via the caller's tensor reference). + + Returns: + weight_total + """ + + if GTP_CONFIG.weight_prefetch and self.next_w is not None: + result = self._get_prefetched_weight(False, skip_weight_cast=True) + else: + result = self._all_gather_weight_on_demand(False, skip_weight_cast=True) + + if ( + GTP_CONFIG.weight_prefetch + and self.prev_w is not None + and self.prev_w._need_weight_prefetch + and self.prev_w._need_weight_prefetch_bwd + ): + # Pre-AG work (quantize, ticket lookup) runs on caller's stream; + # the NCCL collective itself is wrapped on ag_stream inside + # _all_gather_weight (see the async/sync gate there for rationale). + _, handle = self.prev_w._all_gather_weight( + async_op=True, skip_weight_cast=True, cast_noop_flag=None, + fwd=False, nvtx_label=nvtx_label, + ) + self.prev_w._prefetch_handle = handle + + # The unsharded tensor has been returned, no pending work so reset state to NONE + if GTP_CONFIG.check_param_states: + for w in self._weights: + w._set_state(GTPWeightState.NONE) + + if GTP_CONFIG.weight_prefetch and self.next_w is not None: + cache = get_global_GTP_cache() + for w in self._weights: + cache.release(w._ag_ticket_bwd) + + return result + + def batched_all_gather_and_prefetch_bwd(self, nvtx_label=None): + """Batched backward all-gather + prefetch. Wrapper around all_gather_and_prefetch_bwd.""" + assert self.is_routed_expert and self.weight_list is not None + return self.all_gather_and_prefetch_bwd(nvtx_label=nvtx_label) + + def all_gather_and_prefetch( + self, + fwd: bool = True, + skip_weight_cast: bool = False, + cast_noop_flag: torch.Tensor = None, + nvtx_label: str = None, + ): + """ + All-gather current weight and async-prefetch the next weight. + + Returns: + weight_total + """ + if GTP_CONFIG.weight_prefetch and self.prev_w is not None: + result = self._get_prefetched_weight(True, skip_weight_cast, cast_noop_flag) + else: + result = self._all_gather_weight_on_demand(True, skip_weight_cast, cast_noop_flag) + + # Prefetch next weight + if ( + GTP_CONFIG.weight_prefetch + and self.next_w is not None + and self.next_w._need_weight_prefetch + ): + # Pre-AG work on caller; NCCL wrap lives at the collective site + # inside _all_gather_weight. See all_gather_and_prefetch_bwd. + _, handle = self.next_w._all_gather_weight( + async_op=True, + skip_weight_cast=skip_weight_cast, + cast_noop_flag=cast_noop_flag, + fwd=fwd, nvtx_label=nvtx_label, + ) + self.next_w._prefetch_handle = handle + + # The unsharded tensor has been returned, no pending work so reset state to NONE + if GTP_CONFIG.check_param_states: + for w in self._weights: + w._set_state(GTPWeightState.NONE) + + # Lazy population of linked list: link previous weight to current weight + # Uses per-chain state so dense and expert chains never cross-link. + cls = type(self) + chain = cls._get_chain_state(self.chain_id) + if not self.prefetch_initialized: + last_w = chain['last_weight'] + if last_w is not None and last_w.next_w is None: + cls._buffer_link_table_row(last_w, self, chain) + last_w.next_w = self + self.prev_w = last_w + + cache = get_global_GTP_cache() + + # Set the fwd ag buffer + quantizers = [w._quantizer for w in self._weights] + dtypes = [q.dtype if q is not None else w.dtype for q, w in zip(quantizers, self._weights)] + for w, dt in zip(self._weights, dtypes): + w._ag_ticket_fwd = cache.reserve(w, dt, fwd=True) + cache.get(w._ag_ticket_fwd) + cache.release(w._ag_ticket_fwd) + + self.prefetch_initialized = True + chain['last_weight'] = self + elif not chain['link_table_flushed'] and chain['link_table_buffer']: + # Second forward pass: flush the complete table atomically to avoid interleaving + chain['link_table_flushed'] = True + print_rank_0("\n".join(chain['link_table_buffer']) + "\n") + + return result + + def batched_all_gather_and_prefetch(self, **kwargs): + """Batched all-gather + prefetch for expert weights. Wrapper around all_gather_and_prefetch.""" + assert self.is_routed_expert and self.weight_list is not None + return self.all_gather_and_prefetch(**kwargs) + + def get_wgrad_tensor(self): + return _wgrad_pool_get(self._unsharded_shape, self.main_grad.dtype, self.device) + + def register_grad_accum_hook(self, grad_accum_node, hook): + """Register a DDP backward hook to be called from _finalize_wgrad. + + For GTP params, autograd may receive None (async RS) so the normal grad + accumulator hook never fires. Instead, _finalize_wgrad calls the hook + explicitly after RS wait + gradient accumulation, ensuring DDP's + register_grad_ready fires at exactly the right time. + + ``grad_accum_node`` is accepted for caller-API compatibility but the + node itself is not retained — only the hook callable. + """ + del grad_accum_node + self._grad_accum_hook = hook + + @staticmethod + def _handle_megatron_grad_accum(param): + """Handle megatron DDP and gradient accumulation fusion. + + Do NOT set param.grad before calling the hook — the hook checks + param.grad and would accumulate it into main_grad if zero_out_wgrad + is True, corrupting the gradient with a non-zero dummy. + """ + if hasattr(param, "grad_added_to_main_grad"): + param.grad_added_to_main_grad = True + dummy_grad = get_dummy_wgrad(list(param.main_grad.shape), param.dtype) + if getattr(param, '_grad_accum_hook', None) is not None: + param._grad_accum_hook() + + param._set_rs_state(GTPWeightState.NONE) + return dummy_grad + + + def _wait_reduce_scatter(self, finalize_grad=False): + # Enter rs_stream context so handle.wait() + rs_event.record() land + # on rs_stream — mirrors _wait_param_gather for the RS path. + # When finalize_grad=True, main_grad.add_ also runs on rs_stream + # (right after NCCL RS), so it starts during AG drain rather than + # after it — avoids SM-saturation blocking cross-graph overlap. + rs_stream = self._cached_rs_stream + if rs_stream is None: + rs_stream = get_rs_stream(self.chain_id, self.group) + self._cached_rs_stream = rs_stream + with torch.cuda.stream(rs_stream): + if self._wgrad_rs_handle is not None: + self._wgrad_rs_handle.wait() + self._wgrad_rs_handle = None + self.rs_event.record() + if finalize_grad: + cache = get_global_GTP_cache() + for w in self._weights: + w._set_rs_state(GTPWeightState.NONE) + wgrad_rs = cache.get(w._rs_ticket) + w.main_grad.add_(wgrad_rs) + cache.release(w._rs_ticket) + if hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True + self._already_finalized = True + # Release stashed wgrad inputs: UNGRAPHED buffers go back to the pool; + # GRAPHED just drops Python refs (addresses must stay stable for CG). + if getattr(self, '_wgrad_input_bufs', None) is not None: + if self.chain_id == GTPChain.UNGRAPHED.value: + for buf in self._wgrad_input_bufs: + _wgrad_pool_put(buf) + self._wgrad_input_bufs = None + + def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): + """Reduce-scatter one or more wgrads. Returns (outputs, handle). + + Single tensor: plain reduce-scatter (no coalescing). + Multiple tensors: coalesced reduce-scatter. + """ + if nvtx_label is None: + nvtx_label = ( + self._debug_name + + ".bwd" + + (".async" if async_op else ".sync") + ) + + if GTP_CONFIG.check_param_states: + new_rs_state = ( + GTPWeightState.ASYNC_WAIT if async_op else GTPWeightState.DATA_READY_SYNC + ) + for w in self._weights: + w._set_rs_state(new_rs_state) + + if self.pad_length > 0: + wgrads = [torch.nn.functional.pad(w, (0, 0, 0, self.pad_length)) for w in wgrads] + + if async_op: + dtypes = [w.dtype for w in wgrads] + out_buffers = [] + cache = get_global_GTP_cache() + for p, dt in zip(self._weights, dtypes): + if p._rs_ticket is None: + p._rs_ticket = cache.reserve(p, dt, fwd=False, reduce_scatter=True) + out_buffers.append(cache.get(p._rs_ticket)) + else: + out_buffers = [None] * len(wgrads) + + # ASYNC RS: wrap issue on rs_stream — rs_stream's tail then reflects + # the collective's full lifecycle (what external wait_stream(rs_stream) + # drains depend on). The explicit outer→rs_stream sync event preserves + # the wgrad-GEMM writer edge that the bare stream context would drop; + # held on self so PyTorch's event pool can't recycle the handle + # between capture and replay. Mirrors AG path. + # SYNC RS: stay on caller — output ready on return. + if async_op: + outer_stream = torch.cuda.current_stream() + rs_stream = get_rs_stream(self.chain_id, self.group) + if getattr(self, '_rs_outer_sync_event', None) is None: + self._rs_outer_sync_event = torch.cuda.Event() + outer_sync_event = self._rs_outer_sync_event + outer_sync_event.record(outer_stream) + rs_stream.wait_event(outer_sync_event) + rs_ctx = torch.cuda.stream(rs_stream) + else: + rs_ctx = nullcontext() + + with rs_ctx: + if len(wgrads) == 1: + nvtx_range_push(f"{nvtx_label}.gtp_rs") + out, handle = reduce_scatter_along_first_dim( + wgrads[0], self.group, async_op=async_op, output=out_buffers[0] + ) + nvtx_range_pop(f"{nvtx_label}.gtp_rs") + return [out], handle + else: + outputs = [] + nvtx_range_push(f"{nvtx_label}.batched_gtp_rs") + with torch.distributed._coalescing_manager( + group=self.group, + device=wgrads[0].device, + async_ops=async_op, + ) as cm: + for out_buffer, tensor in zip(out_buffers, wgrads): + out, _ = reduce_scatter_along_first_dim(tensor, self.group, output=out_buffer) + outputs.append(out) + nvtx_range_pop(f"{nvtx_label}.batched_gtp_rs") + + return outputs, cm if async_op else None + + def wgrad_reduce_scatter(self, wgrad, nvtx_label=None): + """Reduce-scatter wgrad(s). Sync for last weight, async+deferred for others. + + Accepts a single tensor (non-routed) or list of tensors (routed experts). + + Returns: + Single tensor or list for sync (last weight) — backward should return this. + None or tuple of Nones for async — backward should return this. + """ + batched = isinstance(wgrad, (list, tuple)) + wgrads = list(wgrad) if batched else [wgrad] + weights = self._weights + + # UNGRAPHED-chain wgrads are recycled via the standalone pool (_wgrad_pool_put). + # GRAPHED-chain wgrads cannot pool-recycle because CUDA graphs require + # stable buffer addresses across replay. + poolable = self.chain_id == GTPChain.UNGRAPHED.value + + if GTP_CONFIG.async_reduction and self.prev_w is not None: + # Async reduce-scatter (not last weight — deferred finish). Pre-RS + # work on caller; NCCL wrap lives at the collective site inside + # _reduce_scatter (mirrors the AG prefetch sites). + _, rs_handle = self._reduce_scatter(wgrads, async_op=True, nvtx_label=nvtx_label) + self._wgrad_rs_handle = GTPShardHandle(rs_handle, weights, reduce_scatter=True) + # Stash wgrad input buffers — cannot recycle yet because the async RS + # kernel is still reading them on rs_stream. + self._wgrad_input_bufs = wgrads + ret = tuple([None] * len(wgrads)) if batched else None + else: + # Sync reduce-scatter — reached as the natural chain-head case, recycle immediately + wgrads, _ = self._reduce_scatter(wgrads, async_op=False, nvtx_label=nvtx_label) + torch._foreach_add_([p.main_grad for p in weights], wgrads) + result = [self._handle_megatron_grad_accum(p) for p in weights] + + if poolable: + for buf in wgrads: + _wgrad_pool_put(buf) + ret = result if batched else result[0] + + # Wait for last reduce scatter if it was async + # Currently only support reduce scattering in reverse order + if GTP_CONFIG.async_reduction and self.next_w is not None: + self.next_w._wait_reduce_scatter() + + if getattr(self.next_w, '_already_finalized', False): + self.next_w._already_finalized = False + else: + self.next_w.rs_event.wait() + cache = get_global_GTP_cache() + next_weights = self.next_w._weights + wgrads = [cache.get(w._rs_ticket) for w in next_weights] + torch._foreach_add_([w.main_grad for w in next_weights], wgrads) + for w in next_weights: + self._handle_megatron_grad_accum(w) + cache.release(w._rs_ticket) + + return ret + + def batched_wgrad_reduce_scatter(self, wgrad_list, nvtx_label=None): + """Batched version of wgrad_reduce_scatter.""" + assert self.is_routed_expert and self.weight_list is not None + return self.wgrad_reduce_scatter(wgrad_list, nvtx_label=nvtx_label) + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.Tensor.detach: + with torch._C.DisableTorchFunctionSubclass(): + # Perform the raw detach + result = func(*args, **kwargs) + # Re-wrap it in your subclass so PyTorch is happy + return result.as_subclass(type(self)) + + # 2. For everything else (add, mul, etc.), be transparent/decay. + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +def print_rank_0(message, rank=None): + """If distributed is initialized or rank is specified, print only on rank 0.""" + if rank is not None: + if rank == 0: + print(message, flush=True) + elif torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + +@dataclass +class _TicketSlot: + """Internal slot backing a persistent ticket in the GTP buffer cache.""" + key: tuple # cache key (shape, dtype, ...) + param: 'GTPShardedParam' # for lazy allocation metadata + dtype: object # torch.dtype or tex.DType + reduce_scatter: bool + fwd: bool + chain_id: str = GTPChain.GRAPHED.value # chain this slot belongs to + buf: Optional[torch.Tensor] = field(default=None) # None when released or after clear() + + +class GTPWeightCache: + """ + Ticket-based buffer pool for GTP all-gather / reduce-scatter buffers. + + - ``reserve(param, dtype, fwd)`` → ``ticket`` + Assigns a persistent ticket (no buffer allocated yet). + - ``get(ticket)`` → ``buffer`` + Returns the buffer, lazily allocating from pool or fresh if needed. + - ``release(ticket)`` + Returns the buffer to the pool. Ticket remains valid; next ``get()`` + will re-allocate from the pool. + - ``clear()`` + Drops all buffers and pools. Tickets remain valid; next ``get()`` + lazily allocates fresh buffers. + """ + + # Bytes per element for known dtypes (used for logging). + _BYTES_PER_ELEMENT = { + torch.bfloat16: 2, + torch.float16: 2, + torch.float32: 4, + tex.DType.kFloat4E2M1: 0.5, + tex.DType.kFloat8E4M3: 1, + } + + def __init__(self): + self._pool: Dict[tuple, List[torch.Tensor]] = defaultdict(list) + self._slots: Dict[int, _TicketSlot] = {} + self._next_ticket: int = 0 + self._total_bytes: int = 0 # running total of allocated bytes + self.key_to_allocate_func = {} + + @staticmethod + def _buf_bytes(shape, dtype) -> int: + """Estimate buffer size in bytes.""" + numel = 1 + for d in shape: + numel *= d + bpe = GTPWeightCache._BYTES_PER_ELEMENT.get(dtype, None) + return numel * bpe + + def _allocate_buffer(self, param: 'GTPShardedParam', dtype, reduce_scatter, fwd) -> torch.Tensor: + if reduce_scatter: + out_shape = param._sharded_padded_shape + else: + out_shape = param._unsharded_shape_padded + + if not isinstance(dtype, torch.dtype): + quantizer = param._quantizer + assert quantizer is not None + param._quantizer.set_usage(rowwise=fwd, columnwise=not fwd) + + buf = param._quantizer.make_empty( + out_shape, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + ) + else: + buf = torch.empty( + out_shape, dtype=dtype, device=param.device, memory_format=torch.contiguous_format + ) + + buf_bytes = self._buf_bytes(out_shape, dtype) + self._total_bytes += buf_bytes + print_rank_0( + f"[GTP Cache] +{buf_bytes / 1024**2:.1f} MB (shape={out_shape}, dtype={dtype}) " + f"total={self._total_bytes / 1024**2:.1f} MB id: {id(buf)} fwd: {fwd}" + ) + return buf + + def reserve(self, param: 'GTPShardedParam', dtype, fwd: bool, reduce_scatter=False) -> int: + """Assign a persistent ticket. No buffer is allocated until ``get()``.""" + key = param._get_cache_key(dtype, fwd, reduce_scatter) + ticket = self._next_ticket + self._next_ticket += 1 + + self._slots[ticket] = _TicketSlot( + key=key, param=param, dtype=dtype, reduce_scatter=reduce_scatter, fwd=fwd, + chain_id=getattr(param, 'chain_id', GTPChain.UNGRAPHED.value), + ) + return ticket + + def get(self, ticket: int) -> torch.Tensor: + """Return the buffer for *ticket*, lazily allocating if needed.""" + slot = self._slots[ticket] + if slot.buf is None: + pool = self._pool[slot.key] + slot.buf = pool.pop() if pool else self._allocate_buffer( + slot.param, slot.dtype, slot.reduce_scatter, fwd=slot.fwd + ) + self.key_to_allocate_func[slot.key] = (slot.param, slot.dtype, slot.reduce_scatter, slot.fwd) + + return slot.buf + + def release(self, ticket: int): + """Return the buffer to the pool. Ticket remains valid. + + slot.buf is intentionally NOT cleared: get() must stay idempotent so that + CUDA-graph-captured buffers keep their fixed address across replays, and + reallocate_to_mempool() can find every dense-chain buffer. + """ + slot = self._slots[ticket] + if slot.buf is None: + return + # Use identity check — tensor == tensor returns a multi-element bool tensor + # which crashes in a boolean context ("Boolean value of Tensor is ambiguous"). + if not any(b is slot.buf for b in self._pool.get(slot.key, [])): + self._pool[slot.key].append(slot.buf) + + def clear(self): + """Drop all buffers; tickets remain valid and lazily re-allocate on next get().""" + for slot in self._slots.values(): + slot.buf = None + self._pool.clear() + self._total_bytes = 0 + + def reallocate_to_mempool(self, device, mempool): + """Re-allocate GRAPHED-chain ticket buffers into a CUDA graph memory pool. + + Call BEFORE graph capture so every GRAPHED-chain buffer lives in the capture + pool and no allocations are recorded inside the graph. UNGRAPHED-chain + buffers are left in regular memory (they are never referenced by any + captured graph). + """ + + # Identify keys that belong to the GRAPHED chain + graphed_keys = set() + for slot in self._slots.values(): + if slot.chain_id == GTPChain.GRAPHED.value: + graphed_keys.add(slot.key) + + # Clone only GRAPHED-chain pool buffers into the passed in mempool + self._total_bytes = 0 + new_pool = defaultdict(list) + torch._C._cuda_beginAllocateCurrentThreadToPool(device, mempool) + for key, buffers in self._pool.items(): + if key not in graphed_keys: + continue + new_buffers = [] + for _ in range(len(buffers)): + buf = self._allocate_buffer(*self.key_to_allocate_func[key]) + new_buffers.append(buf) + new_pool[key] = new_buffers + torch._C._cuda_endAllocateToPool(device, mempool) + + # Map each buffer in the old pool to its corresponding new one (GRAPHED only) + old_to_new_buff = {} + for key, old_pool in self._pool.items(): + if key not in graphed_keys: + continue + new = new_pool[key] + for old_buf, new_buf in zip(old_pool, new): + old_to_new_buff[old_buf] = new_buf + + # Replace each GRAPHED slot's reference; keep UNGRAPHED slots unchanged + for slot in self._slots.values(): + if slot.chain_id == GTPChain.GRAPHED.value and slot.buf is not None and slot.buf in old_to_new_buff: + slot.buf = old_to_new_buff[slot.buf] + + # Merge: GRAPHED keys get new buffers, UNGRAPHED keys keep old ones + for key, buffers in self._pool.items(): + if key not in graphed_keys: + new_pool[key] = buffers + self._pool = new_pool + + # Remap quantized params into the CG mempool — but only for params on + # the GRAPHED chain. UNGRAPHED-chain params (embedding, output_layer, + # and MoE paths whose scope is not captured) run eagerly and don't + # need their quantized storage in the CG mempool. + torch._C._cuda_beginAllocateCurrentThreadToPool(device, mempool) + for w in _GTP_PARAMS: + if getattr(w, "chain_id", GTPChain.GRAPHED.value) != GTPChain.GRAPHED.value: + continue + if w.quantized is None: + continue + if isinstance(w.quantized, NVFP4TensorStorage): + w.quantized._rowwise_data = torch.clone(w.quantized._rowwise_data) + w.quantized._columnwise_data = torch.clone(w.quantized._columnwise_data) + w.quantized._rowwise_scale_inv = torch.clone(w.quantized._rowwise_scale_inv) + w.quantized._columnwise_scale_inv = torch.clone(w.quantized._columnwise_scale_inv) + w.quantized._amax_columnwise = torch.clone(w.quantized._amax_columnwise) + w.quantized._amax_rowwise = torch.clone(w.quantized._amax_rowwise) + elif isinstance(w.quantized, MXFP8TensorStorage): + w.quantized._rowwise_data = torch.clone(w.quantized._rowwise_data) + w.quantized._columnwise_data = torch.clone(w.quantized._columnwise_data) + w.quantized._rowwise_scale_inv = torch.clone(w.quantized._rowwise_scale_inv) + w.quantized._columnwise_scale_inv = torch.clone(w.quantized._columnwise_scale_inv) + else: + assert False + torch._C._cuda_endAllocateToPool(device, mempool) + + return + +def get_global_GTP_cache() -> GTPWeightCache: + """Get or lazily create the global cache instance.""" + global _GTP_CACHE + if _GTP_CACHE is None: + _GTP_CACHE = GTPWeightCache() + return _GTP_CACHE + + +def reallocate_gtp_cache_to_mempool(device, mempool): + """Re-allocate all GTP cache buffers into a CUDA graph memory pool.""" + if _GTP_CACHE is not None: + _GTP_CACHE.reallocate_to_mempool(device, mempool) + + +def wait_async_comms(chain_id: str = None, skip_rs: bool = False, finalize_after_drain: bool = False): + """Drain in-flight GTP async AG / RS handles. + + When called inside CUDA graph capture, the drains are captured into that + graph. This is the producer-side hook for cross-graph AG/RS overlap: + captured cudaStreamWaitEvent on an event recorded in a different capture + session is a CUDA no-op, so consumer graphs can't safely wait on + cross-graph events. Instead, the producer drains here and flags the + param; the consumer reads the flag and skips its captured wait. + + Args: + chain_id: If specified, only drain params on this chain. + skip_rs: Drain AG only; leave RS in flight. + finalize_after_drain: After RS drain, also accumulate wgrad into + main_grad. Runs main_grad.add_ on rs_stream (right after + NCCL RS) so it starts during AG drain rather than after, + avoiding SM-saturation that blocks cross-graph overlap. + Falls back to caller-stream _finalize_wgrad if no RS handle. + + Per-param side effects: + * _already_ag_drained = True (if an AG handle was drained) + * _already_finalized = True (if finalize_after_drain=True) + """ + for param in list(_inflight_comm_params): + if chain_id is not None and getattr(param, 'chain_id', GTPChain.UNGRAPHED.value) != chain_id: + continue + had_ag = param._prefetch_handle is not None + param._wait_param_gather() + if had_ag: + param._already_ag_drained = True + if not skip_rs: + param._wait_reduce_scatter(finalize_grad=finalize_after_drain) + if finalize_after_drain and not getattr(param, '_already_finalized', False): + cache = get_global_GTP_cache() + param.rs_event.wait() + for w in param._weights: + GTPShardedParam._finalize_wgrad(w, cache.get(w._rs_ticket)) + cache.release(w._rs_ticket) + param._already_finalized = True + + +@dataclass +class BatchedNVFP4AllGatherAsyncHandle: + """Handle for batched asynchronous NVFP4 all-gathers.""" + output_handles: List[_NVFP4AllGatherAsyncHandle] + outer_async_handle: torch.distributed.Work + _synchronized: bool = False + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.outer_async_handle.wait() + # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. + for output_handle in self.output_handles: + if output_handle is not None: + assert output_handle.async_handle is None + output_handle.post_process_nvfp4_gather() + # release any tensor references just in case + output_handle.output = None + output_handle.columnwise_data_interleaved = None + output_handle.columnwise_scale_inv_interleaved = None + + self._synchronized = True + + +def grouped_gather_along_first_dim( + weights: list, + process_group, + async_op: bool = False, + quantizers: list = None, + output_tensors: list = None, +): + """ + All-gather multiple weights in a single coalesced operation. + + Handles NVFP4 post-processing for both sync and async paths. + """ + # Determine device from first weight. + inp = weights[0] + if isinstance(inp, NVFP4TensorStorage): + device = ( + inp._rowwise_data.device if inp._rowwise_data is not None + else inp._columnwise_data.device + ) + else: + device = inp.device + + weights_all = [] + weight_handles = [] + with torch.distributed._coalescing_manager( + group=process_group, device=device, async_ops=async_op, + ) as gather_coalescing_manager: + for i, weight in enumerate(weights): + weight_all, weight_handle = gather_along_first_dim( + weight, process_group, + quantizer=quantizers[i], + output_tensor=output_tensors[i] if output_tensors is not None else None, + grouped=True, + ) + weights_all.append(weight_all) + weight_handles.append(weight_handle) + + if async_op: + handle = gather_coalescing_manager + has_nvfp4_handles = any( + isinstance(wh, _NVFP4AllGatherAsyncHandle) for wh in weight_handles + ) + if has_nvfp4_handles: + handle = BatchedNVFP4AllGatherAsyncHandle(weight_handles, handle) + else: + for wh in weight_handles: + if isinstance(wh, _NVFP4AllGatherAsyncHandle): + wh.post_process_nvfp4_gather() + handle = None + + return weights_all, handle diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 627144345c..fa343fc61c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -5,6 +5,7 @@ """GroupedLinear API""" from typing import Union, Optional, Callable, Tuple, List from itertools import chain +import traceback import warnings import weakref @@ -24,6 +25,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore +from .generalized_tensor_parallelism import wrap_module_params_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, @@ -100,6 +102,7 @@ def forward( skip_fp8_weight_update, save_original_input, debug, + gtp_size, ) = non_tensor_args if fp8: backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override @@ -114,6 +117,14 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad + if gtp_size > 1: + weights_gtp_sharded = weights + weights = weights[0].batched_all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): if FP8GlobalStateManager.get_fp8_recipe().custom(): @@ -276,12 +287,19 @@ def forward( else: inputmats = [None] * num_gemms - tensors_to_save, tensor_objects = prepare_for_saving( - *inputmats, - *weights_fp8, - *weights, - *biases, - ) + if gtp_size == 1: + tensors_to_save, tensor_objects = prepare_for_saving( + *inputmats, + *weights_fp8, + *weights, + *biases, + ) + else: + tensors_to_save, tensor_objects = prepare_for_saving( + *inputmats, + *weights_gtp_sharded, + *biases, + ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects @@ -303,6 +321,8 @@ def forward( if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + elif gtp_size > 1: + ctx.main_grad_funcs = [weights_gtp_sharded[i].get_wgrad_tensor for i in range(num_gemms)] else: ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_gemms) @@ -332,6 +352,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.gtp_size = gtp_size # backward overrides if backward_override is not None: @@ -357,27 +378,49 @@ def backward( with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx) N = ctx.num_gemms - inputmats = saved_tensors[:N] - weights = saved_tensors[N : 2 * N] - saved_weights = saved_tensors[2 * N : 3 * N] - biases = saved_tensors[3 * N : 4 * N] + if ctx.gtp_size == 1: + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + saved_weights = saved_tensors[2 * N : 3 * N] + biases = saved_tensors[3 * N : 4 * N] + gtp_origin_weights = None + else: + inputmats = saved_tensors[:N] + gtp_origin_weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] # Restore from weakrefs to get original weight python objects # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) # Only needed when fuse_wgrad_accumulation is enabled. origin_weights = [None] * N main_grads = [None] * N + if ctx.gtp_size > 1: + # GTP: origin_weights are the GTPShardedParam list saved by + # the forward (not weakrefs); take them from the saved + # tensors directly. + origin_weights = gtp_origin_weights if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: - origin_weight_refs = ctx.origin_weight_refs - ctx.origin_weight_refs = None - origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs] - assert all( - w is not None for w in origin_weights - ), "weight was removed while fuse_wgrad_accumulation=True" + if ctx.gtp_size == 1: + origin_weight_refs = ctx.origin_weight_refs + ctx.origin_weight_refs = None + origin_weights = [ + ref() if ref is not None else None for ref in origin_weight_refs + ] + assert all( + w is not None for w in origin_weights + ), "weight was removed while fuse_wgrad_accumulation=True" + # Always populate main_grads from main_grad_funcs. For GTP, + # these resolve to get_wgrad_tensor and return unsharded + # scratch buffers (used by grouped_gemm_wgrad as the output + # buffer before batched_wgrad_reduce_scatter). Don't write + # them onto the GTP param's main_grad attribute — that + # attribute is the per-iter sharded accumulator owned by + # mcore DDP, not the GEMM scratch. main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - for origin_weight, main_grad in zip(origin_weights, main_grads): - if main_grad is not None: - origin_weight.main_grad = main_grad + if ctx.gtp_size == 1: + for origin_weight, main_grad in zip(origin_weights, main_grads): + if main_grad is not None: + origin_weight.main_grad = main_grad # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) @@ -428,13 +471,18 @@ def backward( ctx.m_splits, ) - if ctx.is_first_microbatch is not None: + if ctx.gtp_size > 1: + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + if ctx.gtp_size > 1: + weights = origin_weights[0].batched_all_gather_and_prefetch_bwd() + if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD if ctx.fp8 or ctx.debug: @@ -485,6 +533,13 @@ def backward( use_split_accumulator=dgrad_gemm_use_split_accumulator, ) + # Gathered weights are no longer needed after dgrad GEMM. + # For nvfp4, the NVFP4TensorStorage and its sub-tensors (scale_inv etc.) + # would otherwise survive until function return via this local ref. + if ctx.gtp_size > 1: + w_shape = list(weights[0].size()) + del weights + if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD if ctx.fp8: @@ -496,7 +551,7 @@ def backward( if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: - weight_shape = list(weights[0].size()) + weight_shape = w_shape if ctx.gtp_size > 1 else list(weights[0].size()) wgrad_list = tex.bulk_allocate( [weight_shape] * ctx.num_gemms, [ctx.activation_dtype] * ctx.num_gemms, @@ -553,7 +608,7 @@ def backward( use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if not getattr(ctx, "origin_weights_overwrite_main_grad", False) + if ctx.gtp_size == 1 and not getattr(ctx, "origin_weights_overwrite_main_grad", False) else False ), ) @@ -595,10 +650,19 @@ def handle_custom_ddp_from_mcore(weight, main_grad, wgrad): wgrad = None return wgrad - wgrad_list = [ - handle_custom_ddp_from_mcore(weight, main_grad, wgrad) - for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) - ] + if ctx.gtp_size > 1: + wgrad_list = origin_weights[0].batched_wgrad_reduce_scatter(wgrad_list) + # Drop Python refs to wgrad input buffers. The async RS on rs_stream + # still holds C++ refs (via NCCL Work); those are released when + # _wait_reduce_scatter calls handle.wait() + self.handle = None. + # Without this del, main_grads keeps the tensors alive until function + # return, wasting memory during graph capture warmup. + del main_grads + else: + wgrad_list = [ + handle_custom_ddp_from_mcore(weight, main_grad, wgrad) + for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) + ] else: wgrad_list = [None] * ctx.num_gemms @@ -716,6 +780,7 @@ def __init__( single_grouped_weight: bool = False, single_grouped_bias: bool = False, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -771,6 +836,12 @@ def __init__( "Because the TP communication is handled outside of this module." ) + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + assert tp_size == 1, f"TODO(shiqingf): GTP+TP is not well supported yet." + self.parallel_mode = parallel_mode if self.parallel_mode not in GemmParallelModes: raise ValueError( @@ -823,8 +894,22 @@ def __init__( self.init_fp8_metadata(num_gemms=self.num_gemms) is_meta = torch.device(device).type == "meta" + if gtp_group is not None: + # Stash gtp_group + weight_names before reset_parameters so its + # slice hook fires per weightI (gated to GEMM weights so biases + # stay full-size); _gtp_is_grouped routes through the routed-expert + # path and triggers the post-loop weight_list stitch. + self.weight_names = [f"weight{idx}" for idx in range(self.num_gemms)] + self._gtp_group = gtp_group + self._gtp_is_grouped = True + self.reset_parameters(defer_init=is_meta) + if gtp_group is not None: + # No-op safety net for non-TE call sites; slice already done by + # the reset_parameters hook (wrap_module_params_gtp short-circuits). + wrap_module_params_gtp(self, self.weight_names, gtp_group, is_grouped=True) + if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): if name in ("weight", "bias"): @@ -1148,6 +1233,11 @@ def forward( weight_tensors = self._get_weight_tensors() bias_tensors = self._get_bias_tensors() + if self.gtp_size > 1: + weight_tensors[0].setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() if debug: @@ -1202,6 +1292,7 @@ def forward( None, # skip_fp8_weight_update self.save_original_input, debug, + self.gtp_size, ) out, new_workspaces = linear_fn( *autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 8c88f3ee82..ad88c91af6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -28,6 +28,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from .generalized_tensor_parallelism import wrap_module_params_gtp, GTP_CONFIG from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, @@ -143,6 +144,7 @@ def forward( symmetric_ar_type, debug, is_fsdp2, + gtp_size, ) = non_tensor_args if fp8: backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override @@ -297,6 +299,15 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + + if gtp_size > 1: + weight_gtp_sharded = weight + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + new_weight_workspace = None weightmat = weight is_weight_param_quantized = False @@ -420,7 +431,7 @@ def forward( nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") else: out = gemm_out - out = out.view(-1, *inp_shape[1:-1], out_features) + out = out.view(-1, *inp_shape[1:-1], out.shape[-1]) # ------------------------------------------------------ # Output tensor is ready to return... # ------------------------------------------------------ @@ -484,8 +495,9 @@ def forward( wt_save = None tensors_to_save, tensor_objects = prepare_for_saving( inputmat, - wt_save, - weight, + # For GTP, avoid keeping the gathered weightmat in memory for memory saving. + wt_save if gtp_size == 1 else None, + weight if gtp_size == 1 else weight_gtp_sharded, bias, ln_weight, ln_out_to_save, @@ -512,6 +524,8 @@ def forward( if hasattr(weight, "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_func = weight.get_main_grad + elif gtp_size > 1: + ctx.main_grad_func = weight_gtp_sharded.get_wgrad_tensor else: ctx.main_grad_func = lambda: weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer @@ -554,6 +568,7 @@ def forward( qstate.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store ctx.debug = debug + ctx.gtp_size = gtp_size # backward overrides if backward_override is not None: @@ -605,6 +620,9 @@ def backward( rsigma, ) = restore_from_func_ctx(ctx) + if ctx.gtp_size > 1: + weight = saved_weight.all_gather_and_prefetch_bwd() + # Restore from weakref to get original weight python object # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) # Only needed when fuse_wgrad_accumulation is enabled. @@ -622,7 +640,12 @@ def backward( ), "weight was removed while fuse_wgrad_accumulation=True" # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ctx.main_grad_func() if weight is not None else None - if main_grad is not None: + # GTP: main_grad here is the GTPShardedParam's per-iteration + # wgrad scratch buffer (get_wgrad_tensor, unsharded shape), + # not the per-param accumulator. Don't overwrite the param's + # main_grad attribute with it; the wgrad block will RS this + # scratch into the sharded main_grad downstream. + if main_grad is not None and ctx.gtp_size == 1: origin_weight.main_grad = main_grad # Gather intermediate/activation tensors if needed @@ -747,6 +770,20 @@ def backward( # Input tensor is ready for computing grad weight... # -------------------------------------------------- + # When GTPConfig.wgrad_before_dgrad is True (GTP-only opt-in), + # _do_wgrad runs before _do_dgrad and the inline GTP wgrad + # reduce-scatter NCCL overlaps with the dgrad GEMM that + # follows. The prev_w AG prefetch issued by + # all_gather_and_prefetch_bwd above then overlaps with the + # wgrad GEMM. + swap_wgrad_dgrad = ( + ctx.gtp_size > 1 and GTP_CONFIG.wgrad_before_dgrad + ) + + dgrad = None + dgrad_work = None + wgrad = None + # -------------------------------------------------- # Compute grad input tensor # Note: Gradient w.r.t. GEMM input (i.e. norm output). @@ -775,86 +812,94 @@ def backward( ): weight.update_usage(columnwise_usage=True) - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator - - # Update grad input quantizer - if ctx.grad_input_quantizer is not None: - ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - - # Output buffers for Userbuffers reduce-scatter - gemm_out = None - reduce_scatter_out = None - if ctx.ub_overlap_rs_dgrad: - reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device - ) - elif ctx.ub_bulk_wgrad: - gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) + def _do_dgrad(): + nonlocal dgrad, dgrad_work - # dgrad GEMM - # Note: dx = dy * w - nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight - if ctx.backward_override == "dequantized": - if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) - else: - weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) - elif ctx.backward_override == "high_precision": - weight_for_dgrad = saved_weight - if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) - gemm_out, *_, reduce_scatter_out = general_gemm( - weight_for_dgrad, - grad_output, - layout="NN", - grad=True, - quantization_params=ctx.grad_input_quantizer, - out=gemm_out, - out_dtype=ctx.activation_dtype, - use_split_accumulator=use_split_accumulator, - ub=ub_obj_dgrad, - ub_type=ub_type_dgrad, - extra_output=reduce_scatter_out, - bulk_overlap=ctx.ub_bulk_dgrad, - ) - nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") - - # FSDP2 only handles deallocation all-gathered weights that it allocates. - # Columnwise data is derived from rowwise data after allgather for fp8 - # and 2d block-scaled weights in TE managed memory. So we need to clear - # it here. - # (Issues #2681, #2717) - if getattr(ctx, "is_fsdp2", False) and isinstance(weight, QuantizedTensorStorage): - clear_columnwise_cache(weight) - - # Prepare grad input tensor - # Note: Perform tensor-parallel communication - dgrad = None - dgrad_work = None - if ctx.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out - elif ctx.ub_bulk_wgrad: - dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif ctx.parallel_mode == "column" and ctx.tp_size > 1: - nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") - dgrad = gemm_out - if ctx.sequence_parallel: - dgrad, dgrad_work = reduce_scatter_along_first_dim( - dgrad, - ctx.tp_group, - async_op=True, + if isinstance(grad_output, QuantizedTensorStorage): + grad_output.update_usage(rowwise_usage=True) + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) + + # Choose whether to use GEMM kernel with split accumulator + use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + + # Update grad input quantizer + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + # Output buffers for Userbuffers reduce-scatter + gemm_out = None + reduce_scatter_out = None + if ctx.ub_overlap_rs_dgrad: + reduce_scatter_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device ) + elif ctx.ub_bulk_wgrad: + gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) + + # dgrad GEMM + # Note: dx = dy * w + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight + if ctx.backward_override == "dequantized": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) + elif ctx.backward_override == "high_precision": + weight_for_dgrad = saved_weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + gemm_out, *_, reduce_scatter_out = general_gemm( + weight_for_dgrad, + grad_output, + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=gemm_out, + out_dtype=ctx.activation_dtype, + use_split_accumulator=use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=reduce_scatter_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # FSDP2 only handles deallocation all-gathered weights that it allocates. + # Columnwise data is derived from rowwise data after allgather for fp8 + # and 2d block-scaled weights in TE managed memory. So we need to clear + # it here. + # (Issues #2681, #2717) + if getattr(ctx, "is_fsdp2", False) and isinstance(weight, QuantizedTensorStorage): + clear_columnwise_cache(weight) + + # Prepare grad input tensor + # Note: Perform tensor-parallel communication + dgrad = None + dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + dgrad = reduce_scatter_out + elif ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) + elif ctx.parallel_mode == "column" and ctx.tp_size > 1: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + dgrad = gemm_out + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, + ) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") else: - dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) - nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") - else: - dgrad = gemm_out + dgrad = gemm_out # -------------------------------------------------- # Grad input tensor has been computed... @@ -864,8 +909,11 @@ def backward( # Compute grad weight # -------------------------------------------------- - wgrad = None - if ctx.requires_wgrad: + def _do_wgrad(): + nonlocal wgrad, grad_bias, ln_out_total, ln_out_total_work, grad_output, dgrad + if not ctx.requires_wgrad: + return + # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available @@ -929,7 +977,11 @@ def backward( use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if ctx.gtp_size > 1: + # When GTP is enabled, GA is always disabled. GTP Wgrad workflow: + # allocte wgrad_out tmp buffer -> RS(wgrad_gemm) -> GradientAccumulation + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) @@ -1001,6 +1053,9 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) + if ctx.gtp_size > 1: + wgrad = saved_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -1016,17 +1071,35 @@ def wgrad_gemm( if ctx.ln_out_needs_gather: # Gathered input is internal clear_tensor_data(ln_out_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # In swap mode the dgrad GEMM has not run yet and still + # reads grad_output; defer the cleanup to the outer scope + # (grad_output goes out of scope when bwd returns). + if ( + not swap_wgrad_dgrad + and ctx.parallel_mode == "row" + and ctx.sequence_parallel + ): # Gathered grad output tensor is internal clear_tensor_data(grad_output) - # Update grad input if overlapping reduce-scatter with wgrad GEMM - if ctx.ub_bulk_wgrad: + # Update grad input if overlapping reduce-scatter with wgrad GEMM. + # In swap mode dgrad GEMM follows _do_wgrad and writes dgrad + # freshly, so this ub_bulk_wgrad write would be clobbered; skip. + if not swap_wgrad_dgrad and ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = reduce_scatter_out else: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() + # Dispatch wgrad/dgrad in chosen order. In swap mode, _do_wgrad + # issues the GTP wgrad reduce-scatter inline; its NCCL kernel + # overlaps with the dgrad GEMM in _do_dgrad that follows. + if swap_wgrad_dgrad: + _do_wgrad() + _do_dgrad() + else: + _do_dgrad() + _do_wgrad() # -------------------------------------------------- # Grad weight has been computed... # -------------------------------------------------- @@ -1080,7 +1153,9 @@ def wgrad_gemm( if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + if ctx.gtp_size > 1: + pass + elif ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( @@ -1247,6 +1322,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1277,6 +1353,10 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1471,8 +1551,20 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if gtp_group is not None: + # Stash gtp_group before reset_parameters so its slice hook fires + # (slices BF16 weight pre-quantize; skips FP8 wrap for GTP shards). + self._gtp_group = gtp_group + self._gtp_is_grouped = False + self.reset_parameters(defer_init=device == "meta") + if gtp_group is not None: + # No-op safety net for non-TE call sites; slice already done by + # the reset_parameters hook (wrap_module_params_gtp short-circuits). + wrap_module_params_gtp(self, self.weight_names, gtp_group) + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1635,6 +1727,11 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.gtp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1705,6 +1802,7 @@ def forward( self.symmetric_ar_type, debug, self.is_fsdp2, + self.gtp_size, ) out, ln_out, new_weight_workspace = fwd_fn( *autograd_ctx, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dcbb9eaf93..0b6d0535d4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -28,6 +28,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore +from .generalized_tensor_parallelism import wrap_module_params_gtp, GTP_CONFIG from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, @@ -153,6 +154,9 @@ class LinearFwdArgs: cpu_offloading: bool is_grad_enabled: bool + # --- Extended tensor parallelism --- + gtp_size: int = 1 + @dataclass(slots=True) class LinearBwdArgs: @@ -222,6 +226,9 @@ class LinearBwdArgs: cpu_offloading: bool = False owns_input: bool = False + # --- Extended tensor parallelism --- + gtp_size: int = 1 + # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None @@ -399,6 +406,19 @@ def _linear_forward_impl( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + # GTP: replace the sharded weight (local rebind only) with its + # all-gathered counterpart and async-prefetch the next weight in the + # chain. args.weight remains the original GTPShardedParam so that the + # saved-tensor path (slot 2 alias "weight" in _linear_setup_ctx) and + # the backward path receive the sharded reference for re-gathering and + # for wgrad_reduce_scatter. + if args.gtp_size > 1: + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=args.skip_fp8_weight_update, + ) + new_weight_workspace = None weightmat = weight if fp8 or debug: @@ -576,6 +596,10 @@ def _linear_forward_impl( wt_save = weightmat if is_fsdp2 and weightmat is not weight: wt_save = None + # GTP: never save the FP8 weight workspace — backward re-gathers + # the sharded weight from saved_weight (fwd_args.weight). + if args.gtp_size > 1: + wt_save = None # Dedup save slots that alias forward inputs; ``_linear_setup_ctx`` # rebuilds the refs from ``inp`` / ``weight`` / ``bias``. @@ -680,11 +704,16 @@ def _linear_setup_ctx( bwd_args.origin_weight_overwrites_main_grad = getattr(weight, "overwrite_main_grad", False) if hasattr(weight, "__fsdp_param__"): bwd_args.main_grad_func = weight.get_main_grad + elif fwd_args.gtp_size > 1: + # GTP: main_grad lives on the GTPShardedParam (= the original + # sharded weight in fwd_args.weight). + bwd_args.main_grad_func = weight.get_wgrad_tensor else: bwd_args.main_grad_func = lambda: weight.main_grad # Misc bwd_args.cpu_offloading = fwd_args.cpu_offloading + bwd_args.gtp_size = fwd_args.gtp_size if backward_override is not None: bwd_args.fp8 = False @@ -751,7 +780,11 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. origin_weight_python_object is not None ), "weight was removed while fuse_wgrad_accumulation=True" main_grad = bwd_args.main_grad_func() - origin_weight_python_object.main_grad = main_grad + # GTP: main_grad here is the GTPShardedParam's per-iteration + # wgrad buffer (returned by get_wgrad_tensor). Don't overwrite + # the param's main_grad attribute with it. + if bwd_args.gtp_size == 1: + origin_weight_python_object.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when bwd_args.fp8 == False and torch.disttributed.FSDP already @@ -921,7 +954,29 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. dgrad = None dgrad_work = None - if bwd_args.requires_dgrad: + + # GTP: re-gather the sharded weight and async-prefetch the next + # weight in the chain. saved_weight is the original GTPShardedParam + # (saved by _linear_setup_ctx via the "weight" alias). This must + # run regardless of requires_dgrad so prev_w prefetch is issued + # for the next layer's bwd. + if bwd_args.gtp_size > 1: + weight_fp8 = saved_weight.all_gather_and_prefetch_bwd() + + wgrad = None + + # When GTPConfig.wgrad_before_dgrad is True and GTP is active, run + # _do_wgrad before _do_dgrad so the GTP wgrad RS NCCL overlaps with + # the dgrad GEMM that follows. + swap_wgrad_dgrad = ( + bwd_args.gtp_size > 1 and GTP_CONFIG.wgrad_before_dgrad + ) + + def _do_dgrad(): + nonlocal dgrad, dgrad_work, weight_fp8 + if not bwd_args.requires_dgrad: + return + # FSDP2: Re-create workspace from all-gathered weight when # workspace was not saved. (Issue #2681) @@ -1036,8 +1091,11 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. # Compute grad weight # -------------------------------------------------- - wgrad = None - if bwd_args.requires_wgrad: + def _do_wgrad(): + nonlocal wgrad, grad_bias, inputmat_total, inputmat_total_work, grad_output, dgrad + if not bwd_args.requires_wgrad: + return + # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -1102,7 +1160,12 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if bwd_args.is_first_microbatch is not None: + if bwd_args.gtp_size > 1: + # GTP: wgrad GEMM produces a fresh tmp buffer; the + # reduce-scatter + cascade gradient accumulation happens + # downstream in wgrad_reduce_scatter, not via fused GEMM. + accumulate_wgrad_into_param_main_grad = False + elif bwd_args.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( bwd_args.fuse_wgrad_accumulation and not bwd_args.is_first_microbatch ) @@ -1178,6 +1241,13 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output) + # GTP: reduce-scatter the freshly computed wgrad. This + # issues the async RS NCCL immediately so it can overlap + # with the dgrad GEMM that follows in swap mode (and with + # the next layer's bwd in non-swap mode via the cascade). + if bwd_args.gtp_size > 1: + wgrad = saved_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -1190,17 +1260,34 @@ def wgrad_gemm( elif bwd_args.backward_input_needs_gather: # Gathered input tensor is internal clear_tensor_data(inputmat_total) - if bwd_args.parallel_mode == "row" and bwd_args.sequence_parallel: + # In swap mode the dgrad GEMM has not run yet and still + # reads grad_output; defer the cleanup to outer scope. + if ( + not swap_wgrad_dgrad + and bwd_args.parallel_mode == "row" + and bwd_args.sequence_parallel + ): # Gathered grad output tensor is internal clear_tensor_data(grad_output) - # Update grad input if overlapping reduce-scatter with wgrad GEMM - if bwd_args.ub_bulk_wgrad: + # Update grad input if overlapping reduce-scatter with wgrad GEMM. + # In swap mode dgrad GEMM follows _do_wgrad and writes dgrad + # freshly, so this ub_bulk_wgrad write would be clobbered; skip. + if not swap_wgrad_dgrad and bwd_args.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = reduce_scatter_out else: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() + + # Dispatch wgrad/dgrad in chosen order. + if swap_wgrad_dgrad: + _do_wgrad() + _do_dgrad() + else: + _do_dgrad() + _do_wgrad() + # -------------------------------------------------- # Grad weight has been computed... # -------------------------------------------------- @@ -1223,15 +1310,21 @@ def wgrad_gemm( origin_weight_python_object, "grad_added_to_main_grad" ): origin_weight_python_object.grad_added_to_main_grad = True + # Use the param's local shape (the sharded shape under GTP) so the + # dummy gradient autograd receives matches the autograd-saved + # weight shape. main_grad here is bwd_args.main_grad_func() which, + # for GTP, returns the *unsharded* wgrad scratch buffer + # (get_wgrad_tensor) and would produce a shape mismatch. + wgrad_shape = list(origin_weight_python_object.shape) if getattr(origin_weight_python_object, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(main_grad.shape), + wgrad_shape, origin_weight_python_object.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(main_grad.shape), + wgrad_shape, origin_weight_python_object.dtype, ) elif bwd_args.fuse_wgrad_accumulation: @@ -1447,6 +1540,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1475,6 +1569,11 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1644,8 +1743,20 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if gtp_group is not None: + # Stash gtp_group before reset_parameters so its slice hook fires + # (slices BF16 weight pre-quantize; skips FP8 wrap for GTP shards). + self._gtp_group = gtp_group + self._gtp_is_grouped = False + self.reset_parameters(defer_init=device == "meta") + if gtp_group is not None: + # No-op safety net for non-TE call sites; slice already done by + # the reset_parameters hook (wrap_module_params_gtp short-circuits). + wrap_module_params_gtp(self, self.weight_names, gtp_group) + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1776,6 +1887,11 @@ def forward( try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.gtp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1894,6 +2010,8 @@ def forward( # misc cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, + # generalized tensor parallelism + gtp_size=self.gtp_size, ) out, new_weight_workspace = linear_fn( *autograd_ctx, From 915bb941020c7540047ef64596b253679b125df7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 05:53:47 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/test_gtp.py | 411 ++++++++++++------ tests/pytorch/distributed/test_tp_gtp.py | 164 ++++--- .../include/transformer_engine/recipe.h | 2 +- .../common/recipe/multi_amax.cu | 24 +- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/cast.cpp | 15 +- .../pytorch/csrc/extensions/pybind.cpp | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 9 +- transformer_engine/pytorch/distributed.py | 18 +- transformer_engine/pytorch/module/base.py | 14 +- .../module/generalized_tensor_parallelism.py | 274 +++++++----- .../pytorch/module/grouped_linear.py | 7 +- .../pytorch/module/layernorm_linear.py | 4 +- transformer_engine/pytorch/module/linear.py | 7 +- 14 files changed, 587 insertions(+), 368 deletions(-) diff --git a/tests/pytorch/distributed/test_gtp.py b/tests/pytorch/distributed/test_gtp.py index 972af13762..f29075e06f 100644 --- a/tests/pytorch/distributed/test_gtp.py +++ b/tests/pytorch/distributed/test_gtp.py @@ -60,6 +60,7 @@ # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def reset_fp8_state(): yield @@ -79,6 +80,7 @@ def reset_gtp_globals(): # Helpers # --------------------------------------------------------------------------- + def _free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) @@ -112,6 +114,7 @@ def _requires_nvfp4(): # 1. GTPWeightState – state-machine transition tests # --------------------------------------------------------------------------- + class TestGTPWeightState: @staticmethod @@ -147,13 +150,16 @@ def test_rs_state_full_cycle(self): # 2. GTPWeightCache – coat-check buffer pool tests # --------------------------------------------------------------------------- + class TestGTPWeightCache: class _FakeGroup: def __init__(self, size=2): self._size = size + def size(self): return self._size + def rank(self): return 0 @@ -224,7 +230,8 @@ def test_different_shapes_use_distinct_pool_slots(self): t2 = cache.reserve(p2, torch.bfloat16, fwd=True) buf2 = cache.get(t2) assert buf1.shape != buf2.shape - cache.release(t1); cache.release(t2) + cache.release(t1) + cache.release(t2) def test_fwd_bwd_tickets_are_distinct(self): """fwd=True and fwd=False reserves always receive distinct ticket IDs.""" @@ -239,16 +246,17 @@ def test_fwd_bwd_tickets_are_distinct(self): # 3. GTP weight sharding: shard content and alignment padding # --------------------------------------------------------------------------- + def _worker_sharding_aligned(rank, world_size, port): _dist_init(rank, world_size, port) - K, M = world_size * 32, 16 # K divisible by 16*world_size → no padding + K, M = world_size * 32, 16 # K divisible by 16*world_size → no padding full_weight = torch.arange(K * M, dtype=torch.float32).reshape(K, M).cuda() dist.broadcast(full_weight, src=0) gtp_group = dist.new_group(list(range(world_size))) mod = nn.Module() mod.weight = nn.Parameter(full_weight.clone(), requires_grad=False) - wrap_module_params_gtp(mod, ['weight'], gtp_group) + wrap_module_params_gtp(mod, ["weight"], gtp_group) shard = mod.weight rows_per_rank = K // world_size @@ -262,7 +270,7 @@ def _worker_sharding_aligned(rank, world_size, port): def _worker_sharding_padding(rank, world_size, port): _dist_init(rank, world_size, port) alignment = 16 * world_size - K = alignment - 1 # deliberately unaligned + K = alignment - 1 # deliberately unaligned M = 16 full_weight = torch.ones(K, M, dtype=torch.float32).cuda() dist.broadcast(full_weight, src=0) @@ -270,7 +278,7 @@ def _worker_sharding_padding(rank, world_size, port): gtp_group = dist.new_group(list(range(world_size))) mod = nn.Module() mod.weight = nn.Parameter(full_weight.clone(), requires_grad=False) - wrap_module_params_gtp(mod, ['weight'], gtp_group) + wrap_module_params_gtp(mod, ["weight"], gtp_group) shard = mod.weight padded_K = alignment @@ -280,16 +288,18 @@ def _worker_sharding_padding(rank, world_size, port): assert shard.pad_length > 0 # The shard tensor holds only the real rows; get_padded_shard() appends zero rows. padded = shard.get_padded_shard() - assert padded.shape[0] == rows_per_rank, \ - f"rank {rank}: expected padded shard {rows_per_rank} rows, got {padded.shape[0]}" + assert ( + padded.shape[0] == rows_per_rank + ), f"rank {rank}: expected padded shard {rows_per_rank} rows, got {padded.shape[0]}" n_real = K - rank * rows_per_rank assert torch.all(padded[n_real:] == 0), "Padding rows must be zero" else: # pad_length is set globally on every rank's shard (slicer attaches the # global padding amount), so we don't assert anything about it here — # only the last rank's shard contains the actual padding rows. - assert shard.shape[0] == rows_per_rank, \ - f"rank {rank}: expected {rows_per_rank} rows, got {shard.shape[0]}" + assert ( + shard.shape[0] == rows_per_rank + ), f"rank {rank}: expected {rows_per_rank} rows, got {shard.shape[0]}" dist.destroy_process_group() @@ -308,14 +318,18 @@ def test_unaligned_shard_padding(self): # 4. wrap_module_params_gtp: param replacement and GroupedLinear weight_list # --------------------------------------------------------------------------- + def _worker_linear_param_replaced(rank, world_size, port): _dist_init(rank, world_size, port) in_f, out_f = 64, 128 gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=torch.bfloat16, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=torch.bfloat16, + device="cuda", + gtp_group=gtp_group, ) w = layer.weight assert isinstance(w, GTPShardedParam), "weight must be GTPShardedParam" @@ -329,9 +343,13 @@ def _worker_grouped_weight_list(rank, world_size, port): num_gemms, in_f, out_f = 3, 32, 64 gtp_group = dist.new_group(list(range(world_size))) layer = te.GroupedLinear( - num_gemms=num_gemms, in_features=in_f, out_features=out_f, - bias=False, params_dtype=torch.bfloat16, - device="cuda", gtp_group=gtp_group, + num_gemms=num_gemms, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=torch.bfloat16, + device="cuda", + gtp_group=gtp_group, ) w0 = layer.weight0 assert isinstance(w0, GTPShardedParam) @@ -355,18 +373,22 @@ def test_grouped_linear_weight_list(self): # 5. Linear forward/backward numerical correctness # --------------------------------------------------------------------------- + def _worker_linear_correctness(rank, world_size, port): """GTP output == (all-gathered weight) @ input, and dX matches.""" _dist_init(rank, world_size, port) torch.manual_seed(0) - batch, in_f, out_f = 16, 64, 128 # out_f % (16*world_size)==0 → no padding + batch, in_f, out_f = 16, 64, 128 # out_f % (16*world_size)==0 → no padding dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) # Reconstruct full weight from shards (all-gather) @@ -390,9 +412,9 @@ def _worker_linear_correctness(rank, world_size, port): out_ref = out_ref.to(dtype) assert out_gtp.shape == out_ref.shape, f"Shape mismatch {out_gtp.shape} vs {out_ref.shape}" - assert torch.allclose(out_gtp.float(), out_ref.float(), atol=0.1, rtol=0.1), ( - f"Output mismatch max_diff={(out_gtp.float()-out_ref.float()).abs().max():.4f}" - ) + assert torch.allclose( + out_gtp.float(), out_ref.float(), atol=0.1, rtol=0.1 + ), f"Output mismatch max_diff={(out_gtp.float()-out_ref.float()).abs().max():.4f}" # wgrad RS path always accumulates into main_grad; allocate before backward. layer.weight.main_grad = torch.zeros(layer.weight.shape, dtype=dtype, device="cuda") @@ -404,9 +426,9 @@ def _worker_linear_correctness(rank, world_size, port): out_ref.backward(grad_out.float()) assert inp_gtp.grad is not None - assert torch.allclose(inp_gtp.grad.float(), inp_ref.grad.float(), atol=0.1, rtol=0.1), ( - f"dX mismatch max_diff={(inp_gtp.grad.float()-inp_ref.grad.float()).abs().max():.4f}" - ) + assert torch.allclose( + inp_gtp.grad.float(), inp_ref.grad.float(), atol=0.1, rtol=0.1 + ), f"dX mismatch max_diff={(inp_gtp.grad.float()-inp_ref.grad.float()).abs().max():.4f}" dist.destroy_process_group() @@ -420,6 +442,7 @@ def test_forward_backward_correctness(self): # 6. LayerNormLinear forward/backward smoke test # --------------------------------------------------------------------------- + def _worker_layernorm_linear(rank, world_size, port): _dist_init(rank, world_size, port) torch.manual_seed(0) @@ -428,9 +451,12 @@ def _worker_layernorm_linear(rank, world_size, port): gtp_group = dist.new_group(list(range(world_size))) layer = te.LayerNormLinear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) assert isinstance(layer.weight, GTPShardedParam) @@ -456,6 +482,7 @@ def test_forward_backward(self): # 7. GroupedLinear forward/backward smoke test # --------------------------------------------------------------------------- + def _worker_grouped_linear(rank, world_size, port, num_gemms): _dist_init(rank, world_size, port) torch.manual_seed(0) @@ -464,9 +491,13 @@ def _worker_grouped_linear(rank, world_size, port, num_gemms): gtp_group = dist.new_group(list(range(world_size))) layer = te.GroupedLinear( - num_gemms=num_gemms, in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + num_gemms=num_gemms, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) assert isinstance(layer.weight0, GTPShardedParam) @@ -498,6 +529,7 @@ def test_forward_backward(self, num_gemms): # 8. Prefetch chain: next_w / prev_w wiring after first forward pass # --------------------------------------------------------------------------- + def _worker_chain_wired(rank, world_size, port): _dist_init(rank, world_size, port) torch.manual_seed(0) @@ -505,10 +537,22 @@ def _worker_chain_wired(rank, world_size, port): dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) - l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) - l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l0 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) + l1 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) inp = torch.randn(4, in_f, dtype=dtype, device="cuda") dist.broadcast(inp, src=0) @@ -533,10 +577,22 @@ def _worker_chain_async_prefetch(rank, world_size, port): dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) - l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) - l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l0 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) + l1 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) inp = torch.randn(4, in_f, dtype=dtype, device="cuda") dist.broadcast(inp, src=0) @@ -562,6 +618,7 @@ def test_async_prefetch_second_pass(self): # 9. Wgrad reduce-scatter: shape and deferred async path # --------------------------------------------------------------------------- + def _worker_wgrad_shape(rank, world_size, port): """After backward, weight.grad shape must match the local shard shape.""" _dist_init(rank, world_size, port) @@ -571,9 +628,12 @@ def _worker_wgrad_shape(rank, world_size, port): gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, fuse_wgrad_accumulation=False, ) inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) @@ -584,8 +644,7 @@ def _worker_wgrad_shape(rank, world_size, port): w = layer.weight if w.grad is not None: - assert w.grad.shape == w.shape, \ - f"wgrad shape {w.grad.shape} != shard shape {w.shape}" + assert w.grad.shape == w.shape, f"wgrad shape {w.grad.shape} != shard shape {w.shape}" dist.destroy_process_group() @@ -597,10 +656,22 @@ def _worker_multilayer_deferred_rs(rank, world_size, port): dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) - l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) - l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l0 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) + l1 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) inp = torch.randn(8, in_f, dtype=dtype, device="cuda", requires_grad=True) dist.broadcast(inp, src=0) @@ -633,6 +704,7 @@ def test_multilayer_deferred_rs(self): # 10. Multiple microbatches: output must be consistent when weight unchanged # --------------------------------------------------------------------------- + def _worker_microbatches(rank, world_size, port): _dist_init(rank, world_size, port) torch.manual_seed(0) @@ -641,9 +713,12 @@ def _worker_microbatches(rank, world_size, port): gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") dist.broadcast(inp, src=0) @@ -654,8 +729,9 @@ def _worker_microbatches(rank, world_size, port): # Second microbatch with same weight (skip_weight_cast=True path) out2 = layer(inp, is_first_microbatch=False).detach() - assert torch.allclose(out1, out2), \ - f"Microbatch outputs differ; max_diff={(out1-out2).abs().max():.6f}" + assert torch.allclose( + out1, out2 + ), f"Microbatch outputs differ; max_diff={(out1-out2).abs().max():.6f}" dist.destroy_process_group() @@ -669,19 +745,23 @@ def test_consistent_across_microbatches(self): # 11. NVFP4 + GTP: Linear forward/backward, quantized shard setup # --------------------------------------------------------------------------- + def _worker_nvfp4_linear(rank, world_size, port): """Verify that GTP Linear correctly quantizes, all-gathers, and computes with NVFP4.""" _dist_init(rank, world_size, port) torch.manual_seed(0) # batch=32: NVFP4 wgrad GEMM (K=batch) requires K divisible by 32 - batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding + batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) dist.broadcast(inp, src=0) @@ -694,8 +774,9 @@ def _worker_nvfp4_linear(rank, world_size, port): # After the first forward pass setup() must have created a quantized shard w = layer.weight assert w.quantized is not None, "NVFP4 quantized shard must be set after setup()" - assert isinstance(w.quantized, QuantizedTensor), \ - f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" + assert isinstance( + w.quantized, QuantizedTensor + ), f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" assert torch.isfinite(out).all(), "NVFP4 GTP output has non-finite values" @@ -719,19 +800,22 @@ def _worker_nvfp4_linear_unaligned(rank, world_size, port): """ _dist_init(rank, world_size, port) torch.manual_seed(0) - alignment = 16 * world_size # 64 for world_size=4 + alignment = 16 * world_size # 64 for world_size=4 # Choose out_f divisible by 8 (NVFP4 GEMM constraint) but not by 64 (GTP alignment). # With out_f=56: pad_length=8, shard_size=16, last rank gets 8 rows padded to 16. - out_f = alignment - 8 # 56 for world_size=4 + out_f = alignment - 8 # 56 for world_size=4 in_f = 64 batch = 32 dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) dist.broadcast(inp, src=0) @@ -761,6 +845,7 @@ def test_forward_unaligned_padding(self): # 12. NVFP4 + GTP: GroupedLinear forward/backward (coalesced batched all-gather) # --------------------------------------------------------------------------- + def _worker_nvfp4_grouped_linear(rank, world_size, port, num_gemms): """Verify NVFP4 GTP with GroupedLinear (uses grouped_gather_along_first_dim).""" _dist_init(rank, world_size, port) @@ -772,9 +857,13 @@ def _worker_nvfp4_grouped_linear(rank, world_size, port, num_gemms): gtp_group = dist.new_group(list(range(world_size))) layer = te.GroupedLinear( - num_gemms=num_gemms, in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + num_gemms=num_gemms, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) assert isinstance(layer.weight0, GTPShardedParam) @@ -796,8 +885,9 @@ def _worker_nvfp4_grouped_linear(rank, world_size, port, num_gemms): w = getattr(layer, name) assert isinstance(w, GTPShardedParam) assert w.quantized is not None, f"{name}.quantized not set after NVFP4 setup()" - assert isinstance(w.quantized, QuantizedTensor), \ - f"{name}.quantized should be QuantizedTensor, got {type(w.quantized)}" + assert isinstance( + w.quantized, QuantizedTensor + ), f"{name}.quantized should be QuantizedTensor, got {type(w.quantized)}" for i in range(num_gemms): w = getattr(layer, f"weight{i}") @@ -819,20 +909,25 @@ def test_forward_backward(self, num_gemms): # 13. MXFP8 + GTP: Linear forward/backward, quantized shard setup # --------------------------------------------------------------------------- + def _worker_mxfp8_linear(rank, world_size, port): """Verify that GTP Linear correctly quantizes, all-gathers, and computes with MXFP8.""" from transformer_engine.common.recipe import MXFP8BlockScaling + _dist_init(rank, world_size, port) torch.manual_seed(0) # batch=32: MXFP8 wgrad GEMM (K=batch) requires K divisible by MXFP8_BLOCK_SCALING_SIZE=32 - batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding + batch, in_f, out_f = 32, 64, 128 # out_f % (16*world_size)==0 → no padding dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) dist.broadcast(inp, src=0) @@ -845,8 +940,9 @@ def _worker_mxfp8_linear(rank, world_size, port): # After the first forward pass setup() must have created a quantized shard w = layer.weight assert w.quantized is not None, "MXFP8 quantized shard must be set after setup()" - assert isinstance(w.quantized, QuantizedTensor), \ - f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" + assert isinstance( + w.quantized, QuantizedTensor + ), f"weight.quantized should be QuantizedTensor, got {type(w.quantized)}" assert out.shape == (batch, out_f), f"unexpected output shape {out.shape}" assert torch.isfinite(out).all(), "MXFP8 GTP output has non-finite values" @@ -875,6 +971,7 @@ def _worker_mxfp8_linear_unaligned(rank, world_size, port): rows before the GEMM, so the output has the original out_f columns. """ from transformer_engine.common.recipe import MXFP8BlockScaling + _dist_init(rank, world_size, port) torch.manual_seed(0) # out_f=120: M_padded=128, shard_size=32, last rank has 24 rows padded to 32. @@ -886,9 +983,12 @@ def _worker_mxfp8_linear_unaligned(rank, world_size, port): gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) inp = torch.randn(batch, in_f, dtype=dtype, device="cuda", requires_grad=True) dist.broadcast(inp, src=0) @@ -924,6 +1024,7 @@ def test_forward_unaligned_padding(self): # 14. GTPConfig / update_config # --------------------------------------------------------------------------- + class TestGTPConfig: def test_update_pad_for_alignment(self): @@ -951,14 +1052,19 @@ def test_invalid_key_raises(self): # 15. GTPShardedParam properties – shape computations and padding # --------------------------------------------------------------------------- + class TestGTPShardedParamProperties: class _FakeGroup: def __init__(self, size=4, rank=0): self._size = size self._rank = rank - def size(self): return self._size - def rank(self): return self._rank + + def size(self): + return self._size + + def rank(self): + return self._rank def _make_param(self, shape, pad_length=0, group_size=4, group_rank=0): p = GTPShardedParam(torch.zeros(*shape)) @@ -1045,11 +1151,15 @@ def test_strip_padding_multi_row(self): # 16. _get_cache_key – expert vs non-expert, fwd vs bwd # --------------------------------------------------------------------------- + class TestGTPCacheKey: class _FakeGroup: - def size(self): return 4 - def rank(self): return 0 + def size(self): + return 4 + + def rank(self): + return 0 def _param(self, shape=(16, 32), expert_idx=None): p = GTPShardedParam(torch.zeros(*shape)) @@ -1061,8 +1171,9 @@ def _param(self, shape=(16, 32), expert_idx=None): def test_non_expert_key_same_for_fwd_bwd(self): """Non-routed params produce the same cache key for fwd and bwd.""" p = self._param(expert_idx=None) - assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == \ - p._get_cache_key(torch.bfloat16, fwd=False, reduce_scatter=False) + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == p._get_cache_key( + torch.bfloat16, fwd=False, reduce_scatter=False + ) def test_expert_key_differs_fwd_bwd(self): """For quantized (non-torch.dtype) recipes, expert fwd vs bwd keys differ.""" @@ -1070,45 +1181,54 @@ def test_expert_key_differs_fwd_bwd(self): # _get_cache_key differentiates fwd/bwd only for non-torch.dtype objects # (e.g. quantized recipe dtype descriptors). Use a mock to trigger that path. mock_dtype = "fp8" - assert p._get_cache_key(mock_dtype, fwd=True, reduce_scatter=False) != \ - p._get_cache_key(mock_dtype, fwd=False, reduce_scatter=False) + assert p._get_cache_key(mock_dtype, fwd=True, reduce_scatter=False) != p._get_cache_key( + mock_dtype, fwd=False, reduce_scatter=False + ) def test_different_expert_idx_different_keys(self): """Two experts with same shape but different indices get distinct keys.""" p0 = self._param(expert_idx=0) p1 = self._param(expert_idx=1) - assert p0._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ - p1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) + assert p0._get_cache_key( + torch.bfloat16, fwd=True, reduce_scatter=False + ) != p1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) def test_same_expert_idx_same_key(self): """Same-shaped experts with the same idx share a cache key (cross-layer buffer reuse).""" p_l0 = self._param(expert_idx=0) p_l1 = self._param(expert_idx=0) - assert p_l0._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) == \ - p_l1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) + assert p_l0._get_cache_key( + torch.bfloat16, fwd=True, reduce_scatter=False + ) == p_l1._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) def test_different_dtypes_different_keys(self): p = self._param() - assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ - p._get_cache_key(torch.float32, fwd=True, reduce_scatter=False) + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != p._get_cache_key( + torch.float32, fwd=True, reduce_scatter=False + ) def test_rs_key_differs_from_ag_key(self): """reduce_scatter=True key must differ from reduce_scatter=False key.""" p = self._param() - assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != \ - p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=True) + assert p._get_cache_key(torch.bfloat16, fwd=True, reduce_scatter=False) != p._get_cache_key( + torch.bfloat16, fwd=True, reduce_scatter=True + ) # --------------------------------------------------------------------------- # 17. GTPWeightCache.take() deferred vs get() immediate pool return # --------------------------------------------------------------------------- + class TestGTPCacheRelease: """Tests for GTPWeightCache reserve/get/release semantics.""" class _FakeGroup: - def size(self): return 2 - def rank(self): return 0 + def size(self): + return 2 + + def rank(self): + return 0 def _param(self, shape=(8, 4)): p = GTPShardedParam(torch.zeros(*shape)) @@ -1162,61 +1282,73 @@ def test_release_invalid_ticket_raises(self): # 18. tag_gtp_params_with_names – _debug_name population # --------------------------------------------------------------------------- + class TestTagGTPParamsWithNames: def test_debug_name_populated_for_gtp_param(self): """GTPShardedParam._debug_name is set to the dotted parameter path.""" + class _FakeGroup: - def size(self): return 1 - def rank(self): return 0 + def size(self): + return 1 + + def rank(self): + return 0 model = nn.Linear(4, 8, bias=False) w = GTPShardedParam(torch.randn(8, 4)) w.group = _FakeGroup() - model._parameters['weight'] = w + model._parameters["weight"] = w gtp_module.tag_gtp_params_with_names(model) - assert w._debug_name == 'weight', \ - f"Expected 'weight', got '{w._debug_name}'" + assert w._debug_name == "weight", f"Expected 'weight', got '{w._debug_name}'" def test_nested_module_debug_name(self): """Nested module produces a dotted debug name.""" + class _FakeGroup: - def size(self): return 1 - def rank(self): return 0 + def size(self): + return 1 + + def rank(self): + return 0 outer = nn.Sequential(nn.Linear(4, 8, bias=False)) w = GTPShardedParam(torch.randn(8, 4)) w.group = _FakeGroup() - outer._modules['0']._parameters['weight'] = w + outer._modules["0"]._parameters["weight"] = w gtp_module.tag_gtp_params_with_names(outer) - assert w._debug_name == '0.weight', \ - f"Expected '0.weight', got '{w._debug_name}'" + assert w._debug_name == "0.weight", f"Expected '0.weight', got '{w._debug_name}'" def test_non_gtp_params_are_skipped(self): """Plain nn.Parameter instances are silently ignored.""" model = nn.Linear(4, 8) - gtp_module.tag_gtp_params_with_names(model) # must not raise + gtp_module.tag_gtp_params_with_names(model) # must not raise # --------------------------------------------------------------------------- # 19. wrap_module_params_gtp is a no-op when gtp_group.size() == 1 # --------------------------------------------------------------------------- + class TestGTPGroupSizeOne: class _SingletonGroup: - def size(self): return 1 - def rank(self): return 0 + def size(self): + return 1 + + def rank(self): + return 0 def test_no_sharding_when_gtp_size_one(self): """wrap_module_params_gtp must be a no-op for a singleton GTP group.""" mod = nn.Linear(32, 64, bias=False) original_weight = mod.weight - wrap_module_params_gtp(mod, ['weight'], self._SingletonGroup()) - assert mod.weight is original_weight, \ - "gtp_group.size()==1 should leave parameters unchanged" + wrap_module_params_gtp(mod, ["weight"], self._SingletonGroup()) + assert ( + mod.weight is original_weight + ), "gtp_group.size()==1 should leave parameters unchanged" assert not isinstance(mod.weight, GTPShardedParam) @@ -1224,6 +1356,7 @@ def test_no_sharding_when_gtp_size_one(self): # 21. weight_prefetch=False: forward still produces correct output # --------------------------------------------------------------------------- + def _worker_prefetch_disabled(rank, world_size, port): _dist_init(rank, world_size, port) torch.manual_seed(0) @@ -1233,10 +1366,22 @@ def _worker_prefetch_disabled(rank, world_size, port): gtp_module.update_config(weight_prefetch=False) try: - l0 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) - l1 = te.Linear(in_features=in_f, out_features=out_f, bias=False, - params_dtype=dtype, device="cuda", gtp_group=gtp_group) + l0 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) + l1 = te.Linear( + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, + ) inp = torch.randn(4, in_f, dtype=dtype, device="cuda") dist.broadcast(inp, src=0) @@ -1262,17 +1407,21 @@ def test_forward_works_without_prefetch(self): # 22. fuse_wgrad_accumulation=True: wgrad is accumulated into main_grad # --------------------------------------------------------------------------- + def _worker_fuse_wgrad(rank, world_size, port): _dist_init(rank, world_size, port) torch.manual_seed(0) - in_f, out_f = 32, 128 # out_f % (16*world_size)==0, no padding + in_f, out_f = 32, 128 # out_f % (16*world_size)==0, no padding dtype = torch.bfloat16 gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, fuse_wgrad_accumulation=True, ) @@ -1286,8 +1435,9 @@ def _worker_fuse_wgrad(rank, world_size, port): layer(inp, is_first_microbatch=True).sum().backward() # With fused accumulation, wgrad was added into main_grad - assert torch.any(w.main_grad != 0), \ - "main_grad should have been updated by fused wgrad accumulation" + assert torch.any( + w.main_grad != 0 + ), "main_grad should have been updated by fused wgrad accumulation" dist.destroy_process_group() @@ -1301,6 +1451,7 @@ def test_wgrad_accumulated_into_main_grad(self): # 23. _grad_accum_hook is called after reduce-scatter # --------------------------------------------------------------------------- + def _worker_main_grad_updated_after_bwd(rank, world_size, port): """After backward, the wgrad RS path must have accumulated wgrad into main_grad.""" _dist_init(rank, world_size, port) @@ -1310,9 +1461,12 @@ def _worker_main_grad_updated_after_bwd(rank, world_size, port): gtp_group = dist.new_group(list(range(world_size))) layer = te.Linear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, - device="cuda", gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, + device="cuda", + gtp_group=gtp_group, ) # wgrad RS path always accumulates into main_grad; allocate before backward. @@ -1322,8 +1476,9 @@ def _worker_main_grad_updated_after_bwd(rank, world_size, port): dist.broadcast(inp, src=0) layer(inp, is_first_microbatch=True).sum().backward() - assert torch.any(layer.weight.main_grad != 0), \ - "main_grad should have been updated after the reduce-scatter accumulation" + assert torch.any( + layer.weight.main_grad != 0 + ), "main_grad should have been updated after the reduce-scatter accumulation" dist.destroy_process_group() @@ -1331,5 +1486,3 @@ class TestGTPGradAccumHook: def test_main_grad_updated_after_backward(self): _requires_multi_gpu(4) _run_distributed(_worker_main_grad_updated_after_bwd, 4) - - diff --git a/tests/pytorch/distributed/test_tp_gtp.py b/tests/pytorch/distributed/test_tp_gtp.py index 7310f1c450..ce739e43d9 100644 --- a/tests/pytorch/distributed/test_tp_gtp.py +++ b/tests/pytorch/distributed/test_tp_gtp.py @@ -42,6 +42,7 @@ # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def reset_fp8_state(): yield @@ -61,6 +62,7 @@ def reset_gtp_globals(): # Helpers # --------------------------------------------------------------------------- + def _free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) @@ -125,18 +127,21 @@ def _build_groups(rank: int, world_size: int, tp_size: int, gtp_size: int): # 1. TestTPGTPProcessGroups – group sizes and rank membership # --------------------------------------------------------------------------- + def _worker_groups(rank, world_size, port, tp_size, gtp_size): _dist_init(rank, world_size, port) tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size) - assert tp_group.size() == tp_size, \ - f"rank {rank}: TP group size {tp_group.size()} != {tp_size}" - assert gtp_group.size() == gtp_size, \ - f"rank {rank}: GTP group size {gtp_group.size()} != {gtp_size}" - assert dist.get_rank(tp_group) == tp_rank, \ - f"rank {rank}: TP rank {dist.get_rank(tp_group)} != expected {tp_rank}" - assert dist.get_rank(gtp_group) == gtp_rank, \ - f"rank {rank}: GTP rank {dist.get_rank(gtp_group)} != expected {gtp_rank}" + assert tp_group.size() == tp_size, f"rank {rank}: TP group size {tp_group.size()} != {tp_size}" + assert ( + gtp_group.size() == gtp_size + ), f"rank {rank}: GTP group size {gtp_group.size()} != {gtp_size}" + assert ( + dist.get_rank(tp_group) == tp_rank + ), f"rank {rank}: TP rank {dist.get_rank(tp_group)} != expected {tp_rank}" + assert ( + dist.get_rank(gtp_group) == gtp_rank + ), f"rank {rank}: GTP rank {dist.get_rank(gtp_group)} != expected {gtp_rank}" dist.destroy_process_group() @@ -153,25 +158,34 @@ def test_group_sizes_and_ranks(self, tp_size, gtp_size): # 2. TestTPGTPColumnParallelLinear # --------------------------------------------------------------------------- + def _worker_column_shape(rank, world_size, port, tp_size, gtp_size): """Column-parallel: weight shape must be [out_f/(tp_size*gtp_size), in_f].""" _dist_init(rank, world_size, port) tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size) in_f = 64 - out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows + out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows layer = te.Linear( - in_features=in_f, out_features=out_f, - parallel_mode="column", bias=False, params_dtype=torch.bfloat16, - device="cuda", tp_group=tp_group, gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + parallel_mode="column", + bias=False, + params_dtype=torch.bfloat16, + device="cuda", + tp_group=tp_group, + gtp_group=gtp_group, ) expected_rows = out_f // (tp_size * gtp_size) - assert isinstance(layer.weight, GTPShardedParam), \ - f"rank {rank}: weight should be GTPShardedParam" - assert layer.weight.shape == (expected_rows, in_f), \ - f"rank {rank}: expected ({expected_rows}, {in_f}), got {layer.weight.shape}" + assert isinstance( + layer.weight, GTPShardedParam + ), f"rank {rank}: weight should be GTPShardedParam" + assert layer.weight.shape == ( + expected_rows, + in_f, + ), f"rank {rank}: expected ({expected_rows}, {in_f}), got {layer.weight.shape}" dist.destroy_process_group() @@ -183,13 +197,18 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size): tp_group, gtp_group, tp_rank, gtp_rank = _build_groups(rank, world_size, tp_size, gtp_size) batch, in_f = 16, 64 - out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows + out_f = tp_size * gtp_size * 32 # per-rank shard = 32 rows dtype = torch.bfloat16 layer = te.Linear( - in_features=in_f, out_features=out_f, - parallel_mode="column", bias=False, params_dtype=dtype, - device="cuda", tp_group=tp_group, gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + parallel_mode="column", + bias=False, + params_dtype=dtype, + device="cuda", + tp_group=tp_group, + gtp_group=gtp_group, ) # All-gather GTP shards → TP-local full weight [out_f/tp_size, in_f] @@ -197,7 +216,7 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size): all_gtp_shards = [torch.zeros_like(shard) for _ in range(gtp_size)] dist.all_gather(all_gtp_shards, shard, group=gtp_group) tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # strip padding - tp_local_weight = tp_local_weight[:out_f // tp_size] + tp_local_weight = tp_local_weight[: out_f // tp_size] # Same full input on all ranks (column-parallel: each rank processes full input) inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") @@ -206,16 +225,17 @@ def _worker_column_correctness(rank, world_size, port, tp_size, gtp_size): # TE forward: GTP all-gathers weight internally; no TP comm in column-parallel fwd out = layer(inp_te, is_first_microbatch=True) - assert out.shape == (batch, out_f // tp_size), \ - f"rank {rank}: output shape {out.shape} != ({batch}, {out_f // tp_size})" + assert out.shape == ( + batch, + out_f // tp_size, + ), f"rank {rank}: output shape {out.shape} != ({batch}, {out_f // tp_size})" # Reference: this TP rank's output = inp @ tp_local_weight^T ref = inp.float() @ tp_local_weight.T ref = ref.to(dtype) - assert torch.allclose(out.float(), ref.float(), atol=0.1, rtol=0.1), ( - f"rank {rank}: output mismatch, " - f"max_diff={(out.float() - ref.float()).abs().max():.4f}" - ) + assert torch.allclose( + out.float(), ref.float(), atol=0.1, rtol=0.1 + ), f"rank {rank}: output mismatch, max_diff={(out.float() - ref.float()).abs().max():.4f}" # Backward: dX is all-reduced across TP group internally by TE grad = torch.randn_like(out) @@ -247,25 +267,33 @@ def test_forward_backward_correctness(self, tp_size, gtp_size): # 3. TestTPGTPRowParallelLinear # --------------------------------------------------------------------------- + def _worker_row_shape(rank, world_size, port, tp_size, gtp_size): """Row-parallel: weight shape must be [out_f/gtp_size, in_f/tp_size].""" _dist_init(rank, world_size, port) tp_group, gtp_group, _, _ = _build_groups(rank, world_size, tp_size, gtp_size) - in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64 + in_f = tp_size * 64 # TE divides by tp_size → local in_f = 64 out_f = gtp_size * 64 # GTP divides by gtp_size → local out_f = 64 layer = te.Linear( - in_features=in_f, out_features=out_f, - parallel_mode="row", bias=False, params_dtype=torch.bfloat16, - device="cuda", tp_group=tp_group, gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + parallel_mode="row", + bias=False, + params_dtype=torch.bfloat16, + device="cuda", + tp_group=tp_group, + gtp_group=gtp_group, ) expected_shape = (out_f // gtp_size, in_f // tp_size) - assert isinstance(layer.weight, GTPShardedParam), \ - f"rank {rank}: weight should be GTPShardedParam" - assert layer.weight.shape == expected_shape, \ - f"rank {rank}: expected {expected_shape}, got {layer.weight.shape}" + assert isinstance( + layer.weight, GTPShardedParam + ), f"rank {rank}: weight should be GTPShardedParam" + assert ( + layer.weight.shape == expected_shape + ), f"rank {rank}: expected {expected_shape}, got {layer.weight.shape}" dist.destroy_process_group() @@ -277,14 +305,19 @@ def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size): tp_group, gtp_group, tp_rank, _ = _build_groups(rank, world_size, tp_size, gtp_size) batch = 16 - in_f = tp_size * 64 # full in_features + in_f = tp_size * 64 # full in_features out_f = gtp_size * 64 # full out_features dtype = torch.bfloat16 layer = te.Linear( - in_features=in_f, out_features=out_f, - parallel_mode="row", bias=False, params_dtype=dtype, - device="cuda", tp_group=tp_group, gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + parallel_mode="row", + bias=False, + params_dtype=dtype, + device="cuda", + tp_group=tp_group, + gtp_group=gtp_group, ) # Row-parallel: each TP rank takes the corresponding slice of in_f @@ -296,8 +329,10 @@ def _worker_row_forward_backward(rank, world_size, port, tp_size, gtp_size): # TE forward: GTP all-gathers weight, row-parallel all-reduces output across TP out = layer(inp, is_first_microbatch=True) - assert out.shape == (batch, out_f), \ - f"rank {rank}: output shape {out.shape} != ({batch}, {out_f})" + assert out.shape == ( + batch, + out_f, + ), f"rank {rank}: output shape {out.shape} != ({batch}, {out_f})" assert torch.isfinite(out).all(), f"rank {rank}: non-finite output" # wgrad RS path always accumulates into main_grad; allocate before backward. @@ -321,20 +356,25 @@ def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size): dtype = torch.bfloat16 layer = te.Linear( - in_features=in_f, out_features=out_f, - parallel_mode="row", bias=False, params_dtype=dtype, - device="cuda", tp_group=tp_group, gtp_group=gtp_group, + in_features=in_f, + out_features=out_f, + parallel_mode="row", + bias=False, + params_dtype=dtype, + device="cuda", + tp_group=tp_group, + gtp_group=gtp_group, ) # Reconstruct full weight: all-gather GTP shards → TP-local, then all-gather TP shards shard = layer.weight.data.clone() all_gtp_shards = [torch.zeros_like(shard) for _ in range(gtp_size)] dist.all_gather(all_gtp_shards, shard, group=gtp_group) - tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # [out_f, in_f/tp_size] + tp_local_weight = torch.cat(all_gtp_shards, dim=0).float() # [out_f, in_f/tp_size] all_tp_weights = [torch.zeros_like(tp_local_weight) for _ in range(tp_size)] dist.all_gather(all_tp_weights, tp_local_weight, group=tp_group) - full_weight = torch.cat(all_tp_weights, dim=1).float() # [out_f, in_f] + full_weight = torch.cat(all_tp_weights, dim=1).float() # [out_f, in_f] # Full input (same on all ranks; we slice below to simulate row-parallel) full_inp = torch.randn(batch, in_f, dtype=dtype, device="cuda") @@ -348,10 +388,9 @@ def _worker_row_correctness(rank, world_size, port, tp_size, gtp_size): # Reference: full input @ full weight^T — all ranks should see the same output ref = full_inp.float() @ full_weight.T ref = ref.to(dtype) - assert torch.allclose(out.float(), ref.float(), atol=0.1, rtol=0.1), ( - f"rank {rank}: output mismatch, " - f"max_diff={(out.float() - ref.float()).abs().max():.4f}" - ) + assert torch.allclose( + out.float(), ref.float(), atol=0.1, rtol=0.1 + ), f"rank {rank}: output mismatch, max_diff={(out.float() - ref.float()).abs().max():.4f}" dist.destroy_process_group() @@ -380,6 +419,7 @@ def test_forward_correctness(self, tp_size, gtp_size): # 4. TestTPGTPLayerNormLinear – column-parallel smoke test # --------------------------------------------------------------------------- + def _worker_layernorm_linear(rank, world_size, port, tp_size, gtp_size): _dist_init(rank, world_size, port) torch.manual_seed(0) @@ -391,23 +431,29 @@ def _worker_layernorm_linear(rank, world_size, port, tp_size, gtp_size): dtype = torch.bfloat16 layer = te.LayerNormLinear( - in_features=in_f, out_features=out_f, - bias=False, params_dtype=dtype, + in_features=in_f, + out_features=out_f, + bias=False, + params_dtype=dtype, parallel_mode="column", - device="cuda", tp_group=tp_group, gtp_group=gtp_group, + device="cuda", + tp_group=tp_group, + gtp_group=gtp_group, ) - assert isinstance(layer.weight, GTPShardedParam), \ - f"rank {rank}: LayerNormLinear.weight should be GTPShardedParam" + assert isinstance( + layer.weight, GTPShardedParam + ), f"rank {rank}: LayerNormLinear.weight should be GTPShardedParam" expected_rows = out_f // (tp_size * gtp_size) - assert layer.weight.shape == (expected_rows, in_f), \ - f"rank {rank}: unexpected weight shape {layer.weight.shape}" + assert layer.weight.shape == ( + expected_rows, + in_f, + ), f"rank {rank}: unexpected weight shape {layer.weight.shape}" inp = torch.randn(seq, batch, in_f, dtype=dtype, device="cuda", requires_grad=True) dist.broadcast(inp, src=0) out = layer(inp, is_first_microbatch=True) - assert out.shape == (seq, batch, out_f // tp_size), \ - f"rank {rank}: output shape {out.shape}" + assert out.shape == (seq, batch, out_f // tp_size), f"rank {rank}: output shape {out.shape}" assert torch.isfinite(out).all(), f"rank {rank}: non-finite output" # wgrad RS path always accumulates into main_grad; allocate before backward. diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 06a37c1800..f565529e53 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -116,7 +116,7 @@ void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, * \param[in] config Quantization configuration (for noop_tensor). May be NULL. * \param[in] stream CUDA stream used for the operation. */ -void nvte_multi_compute_amax(const NVTETensor *inputs, NVTETensor *outputs, size_t num_tensors, +void nvte_multi_compute_amax(const NVTETensor* inputs, NVTETensor* outputs, size_t num_tensors, const NVTEQuantizationConfig config, cudaStream_t stream); /*! \brief Update an FP8 tensor's scale based on its amax. diff --git a/transformer_engine/common/recipe/multi_amax.cu b/transformer_engine/common/recipe/multi_amax.cu index 5420dde587..c7ebf017f9 100644 --- a/transformer_engine/common/recipe/multi_amax.cu +++ b/transformer_engine/common/recipe/multi_amax.cu @@ -81,8 +81,7 @@ __launch_bounds__(multi_amax_kernel_threads) __global__ InputType max = InputType{0.f}; const int warp_id = threadIdx.x / THREADS_PER_WARP; - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; - tid += gridDim.x * blockDim.x) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { loader.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { @@ -146,18 +145,15 @@ void launch_multi_amax_batch(const MultiAmaxArgs &args, size_t max_numel, Alignm switch (align) { case Alignment::SAME_ALIGNED: - MultiAmaxKernel - <<>>(args, noop_ptr); + MultiAmaxKernel<<>>(args, noop_ptr); break; case Alignment::SAME_UNALIGNED: - MultiAmaxKernel - <<>>(args, noop_ptr); + MultiAmaxKernel<<>>(args, noop_ptr); break; case Alignment::DIFFERENT: // Heterogeneous alignment across tensors — fall back to nvec=1, aligned=true path // which is safe for any pointer alignment. - MultiAmaxKernel<1, true, InputType> - <<>>(args, noop_ptr); + MultiAmaxKernel<1, true, InputType><<>>(args, noop_ptr); break; } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -186,8 +182,8 @@ std::pair build_batch_args(const std::vector &input args.output_rowwise_amax_list[i] = rw_ptr; args.output_columnwise_amax_list[i] = cw_ptr; args.input_numel[i] = N; - args.num_aligned_elements[i] = get_num_aligned_elements(inp.data.dptr, N, nvec, - sizeof(InputType)); + args.num_aligned_elements[i] = + get_num_aligned_elements(inp.data.dptr, N, nvec, sizeof(InputType)); max_numel = std::max(max_numel, N); // Fold this tensor's alignment into the batch decision. CheckAlignment on a @@ -225,11 +221,9 @@ void multi_compute_amax_impl(const NVTETensor *inputs_, NVTETensor *outputs_, si outputs[i] = convertNVTETensorCheck(outputs_[i]); const auto &inp = *inputs[i]; auto &out = *outputs[i]; - NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "nvte_multi_compute_amax: input[", i, - "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode)); - NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), - "nvte_multi_compute_amax: input[", i, + NVTE_CHECK(inp.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "nvte_multi_compute_amax: input[", + i, "] must be unquantized, got scaling_mode=", to_string(inp.scaling_mode)); + NVTE_CHECK(!is_fp8_dtype(inp.data.dtype), "nvte_multi_compute_amax: input[", i, "] must be unquantized, got dtype=", to_string(inp.data.dtype)); if (i == 0) { input_dtype = inp.data.dtype; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6a29c3adb3..758c5424a1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -333,8 +333,7 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, const py::object &output); py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer, - const py::object &output, - std::optional noop_flag); + const py::object &output, std::optional noop_flag); // NVFP4-only multi-tensor amax: fuses N per-expert (zero_amax + amax + D2D replicate) // chains into a single pair of kernel launches (one multi-zero + one multi-amax) that diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 10a507b194..9fb7cf3fde 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -71,7 +71,7 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob * then call `quantize_cast_only_nvfp4` to finish the cast. */ py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, - const py::object &output) { + const py::object &output) { NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer.ptr()), "compute_amax_nvfp4 requires an NVFP4Quantizer"); auto quantizer_cpp = convert_quantizer(quantizer); @@ -101,8 +101,7 @@ py::object compute_amax_nvfp4(const at::Tensor &tensor, py::handle quantizer, * already populated `output`'s amax via compute_amax_nvfp4 + coalesced allreduce. */ py::object quantize_cast_only_nvfp4(const at::Tensor &tensor, py::handle quantizer, - const py::object &output, - std::optional noop_flag) { + const py::object &output, std::optional noop_flag) { NVTE_CHECK(detail::IsNVFP4Quantizers(quantizer.ptr()), "quantize_cast_only_nvfp4 requires an NVFP4Quantizer"); auto quantizer_cpp = convert_quantizer(quantizer); @@ -181,13 +180,13 @@ void compute_multi_amax_nvfp4(const std::vector &tensor_list, TensorWrapper out_cpp; py::object out_py; - NVTE_CHECK(!output_list[i].is_none(), - "compute_multi_amax_nvfp4: output_list[", i, "] is None; caller must pre-allocate"); + NVTE_CHECK(!output_list[i].is_none(), "compute_multi_amax_nvfp4: output_list[", i, + "] is None; caller must pre-allocate"); std::tie(out_cpp, out_py) = quantizer_cpp->convert_and_update_tensor(output_list[i]); - NVTE_CHECK(out_cpp.get_amax().data_ptr != nullptr || - out_cpp.get_columnwise_amax().data_ptr != nullptr, - "compute_multi_amax_nvfp4: output[", i, "] has no amax buffer"); + NVTE_CHECK( + out_cpp.get_amax().data_ptr != nullptr || out_cpp.get_columnwise_amax().data_ptr != nullptr, + "compute_multi_amax_nvfp4: output[", i, "] has no amax buffer"); output_wrappers.emplace_back(std::move(out_cpp)); // quantizer_cpp and out_py are released here at end-of-iteration. diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c15379cae5..696cc51811 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -141,7 +141,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "NVFP4: compute local amax into output's amax buffers; no cast, no allreduce", py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none()); m.def("quantize_cast_only_nvfp4", transformer_engine::pytorch::quantize_cast_only_nvfp4, - "NVFP4: cast using pre-reduced amax in output's amax buffers; skips amax compute and allreduce", + "NVFP4: cast using pre-reduced amax in output's amax buffers; skips amax compute and " + "allreduce", py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("compute_multi_amax_nvfp4", transformer_engine::pytorch::compute_multi_amax_nvfp4, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7e75c74726..7f7927d495 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2268,8 +2268,8 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( } void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag, - bool compute_amax, bool skip_amax_reduction) { + const std::optional& noop_flag, bool compute_amax, + bool skip_amax_reduction) { // Nothing to be done if input is empty if (input.numel() == 0) { return; @@ -2508,8 +2508,7 @@ void NVFP4Quantizer::compute_amax_only(const TensorWrapper& input, TensorWrapper // Only the non-RHT path is supported for the split-phase API today. // RHT path's amax depends on the RHT-rotated view, which is produced // alongside the cast; decoupling amax from cast is not meaningful there. - NVTE_CHECK(!this->with_rht, - "NVFP4Quantizer::compute_amax_only does not support with_rht=true"); + NVTE_CHECK(!this->with_rht, "NVFP4Quantizer::compute_amax_only does not support with_rht=true"); auto stream = at::cuda::getCurrentCUDAStream(); @@ -2539,7 +2538,7 @@ void NVFP4Quantizer::compute_amax_only(const TensorWrapper& input, TensorWrapper } void NVFP4Quantizer::quantize_cast_only(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { + const std::optional& noop_flag) { // Amax is expected to already live in out's amax buffers (e.g. from // compute_amax_only + an external coalesced allreduce). Skip both local // amax compute and the internal allreduce. diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 8bcbb8d6c1..3e4d005883 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -918,7 +918,10 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False, output: torch.Tensor = None + inp: torch.Tensor, + tp_group: dist_group_type, + async_op: bool = False, + output: torch.Tensor = None, ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) @@ -1285,12 +1288,12 @@ def _post_process_nvfp4_gather( # # Fix the interleaved transposed data from gathering along first dim. # out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) # out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) - out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) - out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) + out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) + out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) # # Optionally pad the scaling inverse if needed. # out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) - out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) + out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) @dataclass @@ -1328,8 +1331,8 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, - output_tensor = None, - grouped = False, + output_tensor=None, + grouped=False, ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" @@ -1452,7 +1455,7 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - #TODO: jiemingz + # TODO: jiemingz # out._amax_rowwise = inp._amax_rowwise out._amax_rowwise.copy_(inp._amax_rowwise) @@ -1505,7 +1508,6 @@ def _all_gather_nvfp4( # Transfer amax to output. out._amax_columnwise.copy_(inp._amax_columnwise) - handle = coalesced_handle if async_op else None # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3d75c5761a..94f03eb2cf 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1640,9 +1640,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: gtp_slice_in_reset_parameters, ) - gtp_sharded = gtp_slice_in_reset_parameters( - self, name, param, expert_idx=idx - ) + gtp_sharded = gtp_slice_in_reset_parameters(self, name, param, expert_idx=idx) if gtp_sharded is not None: param = gtp_sharded _gtp_sharded_weight_names.append(name) @@ -1650,11 +1648,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None - if ( - self.primary_weights_in_fp8 - and fp8_meta_index is not None - and gtp_sharded is None - ): + if self.primary_weights_in_fp8 and fp8_meta_index is not None and gtp_sharded is None: # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: @@ -1739,9 +1733,7 @@ def clear(self): gtp_finalize_module_in_reset_parameters, ) - gtp_finalize_module_in_reset_parameters( - self, _gtp_sharded_weight_names - ) + gtp_finalize_module_in_reset_parameters(self, _gtp_sharded_weight_names) @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py index ecc6785314..d6028a72bb 100644 --- a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -14,7 +14,7 @@ from ..distributed import ( gather_along_first_dim, reduce_scatter_along_first_dim, - _NVFP4AllGatherAsyncHandle + _NVFP4AllGatherAsyncHandle, ) from ..quantized_tensor import QuantizedTensor from ..tensor import NVFP4TensorStorage, MXFP8TensorStorage @@ -37,6 +37,7 @@ class GTPChain(str, Enum): Chains never cross-link (prev_w/next_w stay within one chain). CG disabled → single UNGRAPHED chain; full-iteration graph → single GRAPHED. """ + GRAPHED = "GTP_graphed" UNGRAPHED = "GTP_ungraphed" @@ -62,7 +63,7 @@ def set_cuda_graph_scope(scope, moe_shared_expert_overlap: bool = False): _MOE_SHARED_EXPERT_OVERLAP = bool(moe_shared_expert_overlap) -def _classify_param_chain(param_name: str) -> 'GTPChain': +def _classify_param_chain(param_name: str) -> "GTPChain": """Classify an GTPShardedParam by name + active cuda_graph_scope. embedding / output_layer are always UNGRAPHED. Other kinds (mamba mixer, @@ -132,13 +133,12 @@ def classify_gtp_chains(model) -> None: class GTPWeightState(Enum): - NONE = "NONE" # Sharded, no pending operation + NONE = "NONE" # Sharded, no pending operation ASYNC_WAIT = "ASYNC_WAIT" # Async all-gather in progress DATA_READY = "DATA_READY" # Async all-gather complete, result in cache DATA_READY_SYNC = "DATA_READY_SYNC" # Sync all-gather complete, result in cache - # Global GTP buffer cache (persists across clear(); never set to None after creation). _GTP_CACHE = None _GTP_PARAMS = [] @@ -171,7 +171,7 @@ def _wgrad_pool_get(shape: tuple, dtype: torch.dtype, device) -> torch.Tensor: def _wgrad_pool_put(buf: torch.Tensor): """Return a pool-owned buffer for reuse (no-op for untagged buffers; see _wgrad_pool_get).""" - if not getattr(buf, '_from_gtp_wgrad_pool', False): + if not getattr(buf, "_from_gtp_wgrad_pool", False): return key = (tuple(buf.shape), buf.dtype) if key not in _wgrad_buf_pool: @@ -226,9 +226,10 @@ def get_rs_streams_for_chain(chain_id: str) -> list: """RS streams for one chain (all groups that chain has touched).""" return [s for k, s in _RS_STREAMS.items() if k[0] == chain_id] + # Cached once per process: whether the TE build exposes the split-phase APIs. -_COALESCED_AMAX_TE_APIS_AVAILABLE = ( - hasattr(tex, "compute_amax_nvfp4") and hasattr(tex, "quantize_cast_only_nvfp4") +_COALESCED_AMAX_TE_APIS_AVAILABLE = hasattr(tex, "compute_amax_nvfp4") and hasattr( + tex, "quantize_cast_only_nvfp4" ) # Tier-2: multi-tensor amax kernel fuses N per-expert (zero_amax + amax + D2D) chains @@ -257,7 +258,7 @@ def _coalesced_amax_static_eligible(weights): def _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag): """Replace the per-weight (compute_amax + allreduce + cast) loop with: - compute_amax loop → one coalesced allreduce → cast loop.""" + compute_amax loop → one coalesced allreduce → cast loop.""" group = weights[0]._quantizer.amax_reduction_group # Materialize padded shards once; on padded last-rank get_padded_shard() @@ -321,6 +322,7 @@ def _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag): @dataclass class GTPConfig: """Global configuration for Generalized Tensor Parallelism.""" + pad_for_alignment: int = 16 check_param_states: bool = False weight_prefetch: bool = True @@ -350,6 +352,7 @@ class GTPConfig: # path when either binding is missing. coalesce_amax_allreduce: bool = True + GTP_CONFIG = GTPConfig() @@ -396,8 +399,8 @@ def _gtp_slice_one_param(param, gtp_group, *, name=""): assert tensor.shape[0] % gtp_size == 0, ( f"_gtp_slice_one_param: {name}.shape[0]={tensor.shape[0]} is not " f"divisible by gtp_size={gtp_size}. Either enable padding by " - f"setting GTP_CONFIG.pad_for_alignment > 0, or ensure the weight's " - f"dim-0 is a multiple of the GTP group size." + "setting GTP_CONFIG.pad_for_alignment > 0, or ensure the weight's " + "dim-0 is a multiple of the GTP group size." ) pad_length = 0 @@ -450,9 +453,7 @@ def wrap_module_params_gtp(module, weight_names, gtp_group, is_grouped=None): delattr(module, name) gtp_shard = _gtp_slice_one_param(param, gtp_group, name=name) del param - _gtp_attach_attrs( - gtp_shard, gtp_group, is_grouped=bool(is_grouped), expert_idx=idx - ) + _gtp_attach_attrs(gtp_shard, gtp_group, is_grouped=bool(is_grouped), expert_idx=idx) # register the newly sharded param back to the module module._parameters[name] = gtp_shard @@ -478,9 +479,7 @@ def gtp_slice_in_reset_parameters(module, name, param, expert_idx=0): return None is_grouped = bool(getattr(module, "_gtp_is_grouped", False)) gtp_shard = _gtp_slice_one_param(param, gtp_group, name=name) - _gtp_attach_attrs( - gtp_shard, gtp_group, is_grouped=is_grouped, expert_idx=expert_idx - ) + _gtp_attach_attrs(gtp_shard, gtp_group, is_grouped=is_grouped, expert_idx=expert_idx) return gtp_shard @@ -533,15 +532,17 @@ class GTPShardedParam(torch.nn.Parameter): def _get_chain_state(cls, chain_id: str) -> dict: if chain_id not in cls._chain_state: cls._chain_state[chain_id] = { - 'last_weight': None, - 'link_node_count': 0, - 'link_table_buffer': [], - 'link_table_flushed': False, + "last_weight": None, + "link_node_count": 0, + "link_table_buffer": [], + "link_table_flushed": False, } return cls._chain_state[chain_id] @classmethod - def _buffer_link_table_row(cls, prev: "GTPShardedParam", curr: "GTPShardedParam", chain: dict) -> None: + def _buffer_link_table_row( + cls, prev: "GTPShardedParam", curr: "GTPShardedParam", chain: dict + ) -> None: """Buffer one row of the prefetch-link table (flushed atomically on the second forward pass).""" _W = 70 @@ -549,31 +550,31 @@ def _layer_id(name: str) -> str: m = re.search(r"\d+", name) return m.group() if m else "-" - chain['link_node_count'] += 1 - if chain['link_node_count'] == 1: - chain_id = getattr(curr, 'chain_id', GTPChain.UNGRAPHED.value) - chain['link_table_buffer'].append( - f"\n[{chain_id} chain]" - f"\n{'node_id':>7} | {'layer_id':>8} | {'curr_weight_name':<{_W}} | prev_weight_name" - f"\n{'-'*7}-+-{'-'*8}-+-{'-'*_W}-+-{'-'*_W}" + chain["link_node_count"] += 1 + if chain["link_node_count"] == 1: + chain_id = getattr(curr, "chain_id", GTPChain.UNGRAPHED.value) + chain["link_table_buffer"].append( + f"\n[{chain_id} chain]\n{'node_id':>7} | {'layer_id':>8} |" + f" {'curr_weight_name':<{_W}} |" + f" prev_weight_name\n{'-'*7}-+-{'-'*8}-+-{'-'*_W}-+-{'-'*_W}" ) # Seed weight (first GTP param) as row 0 - chain['link_table_buffer'].append( + chain["link_table_buffer"].append( f"{'0':>7} | {_layer_id(prev._debug_name):>8} | {prev._debug_name:<{_W}} | -" ) - chain['link_table_buffer'].append( + chain["link_table_buffer"].append( f"{chain['link_node_count']:>7} | {_layer_id(curr._debug_name):>8} | " f"{curr._debug_name:<{_W}} | {prev._debug_name}" ) @staticmethod def __new__(cls, tensor, *args, **kwargs): - requires_grad = kwargs.get('requires_grad', True) + requires_grad = kwargs.get("requires_grad", True) return super(GTPShardedParam, cls).__new__(cls, tensor, requires_grad=requires_grad) def __init__(self, x, *args, **kwargs): super().__init__() - + # all gather self.state = GTPWeightState.NONE self._ag_ticket_fwd = None @@ -627,9 +628,10 @@ def setup(self, weight_quantizer=None): """Set quantizer and create quantized shard.""" if self._quantizer is None: + def _configure_quantizer(q, group): q = q.copy() - if hasattr(q, 'with_amax_reduction'): + if hasattr(q, "with_amax_reduction"): q.with_amax_reduction = True q.amax_reduction_group = group q.internal = False @@ -640,14 +642,18 @@ def _configure_quantizer(q, group): q.optimize_for_gemm = not isinstance(q, MXFP8Quantizer) return q - weights = self.weight_list if self.is_routed_expert and self.weight_list is not None else [self] + weights = ( + self.weight_list + if self.is_routed_expert and self.weight_list is not None + else [self] + ) for quantizer, weight in zip(weight_quantizer, weights): if quantizer is None: continue weight._quantizer = _configure_quantizer(quantizer, weight.group) weight.quantized = weight._quantizer.quantize(weight.get_padded_shard()) - weight.quantized.is_routed_expert = getattr(weight, 'is_routed_expert', False) + weight.quantized.is_routed_expert = getattr(weight, "is_routed_expert", False) # fp8_param_gather: the init quantize above already produced a # valid FP8 cache from the BF16 shard; flag did_cast so iter-0's # forward _quantize_if_needed short-circuits and the redundant @@ -701,9 +707,16 @@ def _get_cache_key(self, dtype, fwd: bool, reduce_scatter: bool) -> tuple: For expert weights gathered in parallel, self.expert_idx distinguishes them so each gets a distinct buffer, while same-indexed experts across layers share. """ - + if not isinstance(dtype, torch.dtype): - return (self._unsharded_shape_padded, dtype, fwd, not fwd, self.expert_idx, reduce_scatter) + return ( + self._unsharded_shape_padded, + dtype, + fwd, + not fwd, + self.expert_idx, + reduce_scatter, + ) return (self._unsharded_shape_padded, dtype, self.expert_idx, reduce_scatter) def _quantize_if_needed(self, skip_weight_cast=False, cast_noop_flag=None): @@ -734,23 +747,22 @@ def _strip_padding(self, tensor): return tensor if isinstance(tensor, QuantizedTensor): - assert isinstance(tensor, (NVFP4TensorStorage, MXFP8TensorStorage)), \ - f"Unsupported quantized tensor type for GTP padding: {type(tensor)}" + assert isinstance( + tensor, (NVFP4TensorStorage, MXFP8TensorStorage) + ), f"Unsupported quantized tensor type for GTP padding: {type(tensor)}" metadata = tensor.get_metadata() if metadata.get("rowwise_data") is not None: - metadata["rowwise_data"] = metadata["rowwise_data"][:-self.pad_length] + metadata["rowwise_data"] = metadata["rowwise_data"][: -self.pad_length] if metadata.get("columnwise_data") is not None: if isinstance(tensor, NVFP4TensorStorage): # NVFP4 transposes columnwise and packs 2 values per byte metadata["columnwise_data"] = metadata["columnwise_data"][ - ..., :-self.pad_length // 2 + ..., : -self.pad_length // 2 ].contiguous() else: # MXFP8 columnwise is not transposed, strip first dim - metadata["columnwise_data"] = metadata["columnwise_data"][ - :-self.pad_length - ] + metadata["columnwise_data"] = metadata["columnwise_data"][: -self.pad_length] M = self._unsharded_shape[0] if isinstance(tensor, NVFP4TensorStorage): # NVFP4 scale_inv shapes (see NVFP4Quantizer.get_scale_shape): @@ -764,9 +776,9 @@ def _strip_padding(self, tensor): m_tiles = round_up_to_nearest_multiple( math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4 ) - metadata["columnwise_scale_inv"] = ( - metadata["columnwise_scale_inv"][:, :m_tiles].contiguous() - ) + metadata["columnwise_scale_inv"] = metadata["columnwise_scale_inv"][ + :, :m_tiles + ].contiguous() else: # MXFP8 scale_inv shapes (see MXFP8Quantizer.get_scale_shape): # rowwise_scale_inv: [round_up(M, 128), round_up(K//32, 4)] @@ -776,24 +788,18 @@ def _strip_padding(self, tensor): m_rows = round_up_to_nearest_multiple(M, 128) metadata["rowwise_scale_inv"] = metadata["rowwise_scale_inv"][:m_rows] if metadata.get("columnwise_scale_inv") is not None: - m_tiles = round_up_to_nearest_multiple( - M // MXFP8_BLOCK_SCALING_SIZE, 4 - ) - metadata["columnwise_scale_inv"] = ( - metadata["columnwise_scale_inv"][:m_tiles] - ) + m_tiles = round_up_to_nearest_multiple(M // MXFP8_BLOCK_SCALING_SIZE, 4) + metadata["columnwise_scale_inv"] = metadata["columnwise_scale_inv"][:m_tiles] return type(tensor)(**metadata, shape=self._unsharded_shape, dtype=torch.bfloat16) else: - return tensor[:-self.pad_length] + return tensor[: -self.pad_length] def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nvtx_label=None): """Quantize (if needed) and all-gather weight. Returns (weight_total, handle).""" if nvtx_label is None: nvtx_label = ( - self._debug_name - + (".fwd" if fwd else ".bwd") - + (".async" if async_op else ".sync") + self._debug_name + (".fwd" if fwd else ".bwd") + (".async" if async_op else ".sync") ) nvtx_range_push(f"{nvtx_label}.all_gather_weight") @@ -817,9 +823,7 @@ def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nv # Per-call: match the skip_weight_cast gate in _quantize_if_needed # (fire when either skip_weight_cast is False or cast_noop_flag # was provided by the FP8/NVFP4 recipe). - use_coalesced = static_ok and not ( - skip_weight_cast is True and cast_noop_flag is None - ) + use_coalesced = static_ok and not (skip_weight_cast is True and cast_noop_flag is None) else: use_coalesced = False @@ -877,8 +881,9 @@ def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nv self._cached_gtp_group = gtp_group if GTP_CONFIG.check_param_states and len(gather_weights) > 1: # Debug invariant: batched AG needs distinct output buffers per expert. - assert len(set(id(b) for b in out_buffers)) == len(out_buffers), \ - "Duplicate output buffers in batched all-gather — experts need distinct cache keys" + assert len(set(id(b) for b in out_buffers)) == len( + out_buffers + ), "Duplicate output buffers in batched all-gather — experts need distinct cache keys" # ASYNC AG: wrap issue on ag_stream — ag_stream's tail then reflects # the collective's full lifecycle (what external wait_stream(ag_stream) @@ -890,7 +895,7 @@ def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nv if async_op: outer_stream = torch.cuda.current_stream() ag_stream = get_ag_stream(self.chain_id, gtp_group) - if getattr(self, '_ag_outer_sync_event', None) is None: + if getattr(self, "_ag_outer_sync_event", None) is None: self._ag_outer_sync_event = torch.cuda.Event() outer_sync_event = self._ag_outer_sync_event outer_sync_event.record(outer_stream) @@ -903,7 +908,8 @@ def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nv if len(gather_weights) > 1: nvtx_range_push(f"{nvtx_label}.batched_gtp_ag") results, handle = grouped_gather_along_first_dim( - gather_weights, gtp_group, + gather_weights, + gtp_group, async_op=async_op, quantizers=quantizers, output_tensors=out_buffers, @@ -912,7 +918,8 @@ def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nv else: nvtx_range_push(f"{nvtx_label}.gtp_ag") weight_total, handle = gather_along_first_dim( - gather_weights[0], gtp_group, + gather_weights[0], + gtp_group, quantizer=quantizers[0], async_op=async_op, output_tensor=out_buffers[0] if out_buffers is not None else None, @@ -954,7 +961,7 @@ def _all_gather_weight_on_demand(self, fwd, skip_weight_cast=False, cast_noop_fl ) result = result if self.is_routed_expert else [result] result = [self._strip_padding(r) for r in result] - result = [r.detach().requires_grad_(w.requires_grad) for r, w in zip(result,self._weights)] + result = [r.detach().requires_grad_(w.requires_grad) for r, w in zip(result, self._weights)] return result if self.is_routed_expert else result[0] def _get_prefetched_weight(self, fwd, skip_weight_cast=False, cast_noop_flag=None): @@ -969,10 +976,10 @@ def _get_prefetched_weight(self, fwd, skip_weight_cast=False, cast_noop_flag=Non ), ( f"[GTP] _get_prefetched_weight({'fwd' if fwd else 'bwd'}) on " f"{self._debug_name} with state={w.state!r} — no AG issued; " - f"cache.get() would return stale data. Check the chain's " - f"_need_weight_prefetch flag and issuer's prefetch logic." + "cache.get() would return stale data. Check the chain's " + "_need_weight_prefetch flag and issuer's prefetch logic." ) - _was_drained = getattr(self, '_already_ag_drained', False) + _was_drained = getattr(self, "_already_ag_drained", False) if _was_drained: # Producer already drained via wait_async_comms; skip the captured # cross-graph wait (CUDA no-op anyway). Correctness is provided by @@ -1024,8 +1031,11 @@ def all_gather_and_prefetch_bwd(self, nvtx_label=None): # the NCCL collective itself is wrapped on ag_stream inside # _all_gather_weight (see the async/sync gate there for rationale). _, handle = self.prev_w._all_gather_weight( - async_op=True, skip_weight_cast=True, cast_noop_flag=None, - fwd=False, nvtx_label=nvtx_label, + async_op=True, + skip_weight_cast=True, + cast_noop_flag=None, + fwd=False, + nvtx_label=nvtx_label, ) self.prev_w._prefetch_handle = handle @@ -1076,7 +1086,8 @@ def all_gather_and_prefetch( async_op=True, skip_weight_cast=skip_weight_cast, cast_noop_flag=cast_noop_flag, - fwd=fwd, nvtx_label=nvtx_label, + fwd=fwd, + nvtx_label=nvtx_label, ) self.next_w._prefetch_handle = handle @@ -1090,7 +1101,7 @@ def all_gather_and_prefetch( cls = type(self) chain = cls._get_chain_state(self.chain_id) if not self.prefetch_initialized: - last_w = chain['last_weight'] + last_w = chain["last_weight"] if last_w is not None and last_w.next_w is None: cls._buffer_link_table_row(last_w, self, chain) last_w.next_w = self @@ -1100,18 +1111,20 @@ def all_gather_and_prefetch( # Set the fwd ag buffer quantizers = [w._quantizer for w in self._weights] - dtypes = [q.dtype if q is not None else w.dtype for q, w in zip(quantizers, self._weights)] + dtypes = [ + q.dtype if q is not None else w.dtype for q, w in zip(quantizers, self._weights) + ] for w, dt in zip(self._weights, dtypes): w._ag_ticket_fwd = cache.reserve(w, dt, fwd=True) cache.get(w._ag_ticket_fwd) cache.release(w._ag_ticket_fwd) self.prefetch_initialized = True - chain['last_weight'] = self - elif not chain['link_table_flushed'] and chain['link_table_buffer']: + chain["last_weight"] = self + elif not chain["link_table_flushed"] and chain["link_table_buffer"]: # Second forward pass: flush the complete table atomically to avoid interleaving - chain['link_table_flushed'] = True - print_rank_0("\n".join(chain['link_table_buffer']) + "\n") + chain["link_table_flushed"] = True + print_rank_0("\n".join(chain["link_table_buffer"]) + "\n") return result @@ -1148,13 +1161,12 @@ def _handle_megatron_grad_accum(param): if hasattr(param, "grad_added_to_main_grad"): param.grad_added_to_main_grad = True dummy_grad = get_dummy_wgrad(list(param.main_grad.shape), param.dtype) - if getattr(param, '_grad_accum_hook', None) is not None: + if getattr(param, "_grad_accum_hook", None) is not None: param._grad_accum_hook() param._set_rs_state(GTPWeightState.NONE) return dummy_grad - def _wait_reduce_scatter(self, finalize_grad=False): # Enter rs_stream context so handle.wait() + rs_event.record() land # on rs_stream — mirrors _wait_param_gather for the RS path. @@ -1182,7 +1194,7 @@ def _wait_reduce_scatter(self, finalize_grad=False): self._already_finalized = True # Release stashed wgrad inputs: UNGRAPHED buffers go back to the pool; # GRAPHED just drops Python refs (addresses must stay stable for CG). - if getattr(self, '_wgrad_input_bufs', None) is not None: + if getattr(self, "_wgrad_input_bufs", None) is not None: if self.chain_id == GTPChain.UNGRAPHED.value: for buf in self._wgrad_input_bufs: _wgrad_pool_put(buf) @@ -1195,16 +1207,10 @@ def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): Multiple tensors: coalesced reduce-scatter. """ if nvtx_label is None: - nvtx_label = ( - self._debug_name - + ".bwd" - + (".async" if async_op else ".sync") - ) + nvtx_label = self._debug_name + ".bwd" + (".async" if async_op else ".sync") if GTP_CONFIG.check_param_states: - new_rs_state = ( - GTPWeightState.ASYNC_WAIT if async_op else GTPWeightState.DATA_READY_SYNC - ) + new_rs_state = GTPWeightState.ASYNC_WAIT if async_op else GTPWeightState.DATA_READY_SYNC for w in self._weights: w._set_rs_state(new_rs_state) @@ -1232,7 +1238,7 @@ def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): if async_op: outer_stream = torch.cuda.current_stream() rs_stream = get_rs_stream(self.chain_id, self.group) - if getattr(self, '_rs_outer_sync_event', None) is None: + if getattr(self, "_rs_outer_sync_event", None) is None: self._rs_outer_sync_event = torch.cuda.Event() outer_sync_event = self._rs_outer_sync_event outer_sync_event.record(outer_stream) @@ -1258,7 +1264,9 @@ def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): async_ops=async_op, ) as cm: for out_buffer, tensor in zip(out_buffers, wgrads): - out, _ = reduce_scatter_along_first_dim(tensor, self.group, output=out_buffer) + out, _ = reduce_scatter_along_first_dim( + tensor, self.group, output=out_buffer + ) outputs.append(out) nvtx_range_pop(f"{nvtx_label}.batched_gtp_rs") @@ -1297,7 +1305,7 @@ def wgrad_reduce_scatter(self, wgrad, nvtx_label=None): wgrads, _ = self._reduce_scatter(wgrads, async_op=False, nvtx_label=nvtx_label) torch._foreach_add_([p.main_grad for p in weights], wgrads) result = [self._handle_megatron_grad_accum(p) for p in weights] - + if poolable: for buf in wgrads: _wgrad_pool_put(buf) @@ -1308,7 +1316,7 @@ def wgrad_reduce_scatter(self, wgrad, nvtx_label=None): if GTP_CONFIG.async_reduction and self.next_w is not None: self.next_w._wait_reduce_scatter() - if getattr(self.next_w, '_already_finalized', False): + if getattr(self.next_w, "_already_finalized", False): self.next_w._already_finalized = False else: self.next_w.rs_event.wait() @@ -1354,15 +1362,17 @@ def print_rank_0(message, rank=None): else: print(message, flush=True) + @dataclass class _TicketSlot: """Internal slot backing a persistent ticket in the GTP buffer cache.""" - key: tuple # cache key (shape, dtype, ...) - param: 'GTPShardedParam' # for lazy allocation metadata - dtype: object # torch.dtype or tex.DType + + key: tuple # cache key (shape, dtype, ...) + param: "GTPShardedParam" # for lazy allocation metadata + dtype: object # torch.dtype or tex.DType reduce_scatter: bool fwd: bool - chain_id: str = GTPChain.GRAPHED.value # chain this slot belongs to + chain_id: str = GTPChain.GRAPHED.value # chain this slot belongs to buf: Optional[torch.Tensor] = field(default=None) # None when released or after clear() @@ -1395,7 +1405,7 @@ def __init__(self): self._pool: Dict[tuple, List[torch.Tensor]] = defaultdict(list) self._slots: Dict[int, _TicketSlot] = {} self._next_ticket: int = 0 - self._total_bytes: int = 0 # running total of allocated bytes + self._total_bytes: int = 0 # running total of allocated bytes self.key_to_allocate_func = {} @staticmethod @@ -1407,7 +1417,9 @@ def _buf_bytes(shape, dtype) -> int: bpe = GTPWeightCache._BYTES_PER_ELEMENT.get(dtype, None) return numel * bpe - def _allocate_buffer(self, param: 'GTPShardedParam', dtype, reduce_scatter, fwd) -> torch.Tensor: + def _allocate_buffer( + self, param: "GTPShardedParam", dtype, reduce_scatter, fwd + ) -> torch.Tensor: if reduce_scatter: out_shape = param._sharded_padded_shape else: @@ -1419,8 +1431,8 @@ def _allocate_buffer(self, param: 'GTPShardedParam', dtype, reduce_scatter, fwd) param._quantizer.set_usage(rowwise=fwd, columnwise=not fwd) buf = param._quantizer.make_empty( - out_shape, - dtype=torch.bfloat16, + out_shape, + dtype=torch.bfloat16, device=torch.cuda.current_device(), ) else: @@ -1436,15 +1448,19 @@ def _allocate_buffer(self, param: 'GTPShardedParam', dtype, reduce_scatter, fwd) ) return buf - def reserve(self, param: 'GTPShardedParam', dtype, fwd: bool, reduce_scatter=False) -> int: + def reserve(self, param: "GTPShardedParam", dtype, fwd: bool, reduce_scatter=False) -> int: """Assign a persistent ticket. No buffer is allocated until ``get()``.""" key = param._get_cache_key(dtype, fwd, reduce_scatter) ticket = self._next_ticket self._next_ticket += 1 self._slots[ticket] = _TicketSlot( - key=key, param=param, dtype=dtype, reduce_scatter=reduce_scatter, fwd=fwd, - chain_id=getattr(param, 'chain_id', GTPChain.UNGRAPHED.value), + key=key, + param=param, + dtype=dtype, + reduce_scatter=reduce_scatter, + fwd=fwd, + chain_id=getattr(param, "chain_id", GTPChain.UNGRAPHED.value), ) return ticket @@ -1453,11 +1469,20 @@ def get(self, ticket: int) -> torch.Tensor: slot = self._slots[ticket] if slot.buf is None: pool = self._pool[slot.key] - slot.buf = pool.pop() if pool else self._allocate_buffer( - slot.param, slot.dtype, slot.reduce_scatter, fwd=slot.fwd + slot.buf = ( + pool.pop() + if pool + else self._allocate_buffer( + slot.param, slot.dtype, slot.reduce_scatter, fwd=slot.fwd + ) + ) + self.key_to_allocate_func[slot.key] = ( + slot.param, + slot.dtype, + slot.reduce_scatter, + slot.fwd, ) - self.key_to_allocate_func[slot.key] = (slot.param, slot.dtype, slot.reduce_scatter, slot.fwd) - + return slot.buf def release(self, ticket: int): @@ -1522,7 +1547,11 @@ def reallocate_to_mempool(self, device, mempool): # Replace each GRAPHED slot's reference; keep UNGRAPHED slots unchanged for slot in self._slots.values(): - if slot.chain_id == GTPChain.GRAPHED.value and slot.buf is not None and slot.buf in old_to_new_buff: + if ( + slot.chain_id == GTPChain.GRAPHED.value + and slot.buf is not None + and slot.buf in old_to_new_buff + ): slot.buf = old_to_new_buff[slot.buf] # Merge: GRAPHED keys get new buffers, UNGRAPHED keys keep old ones @@ -1559,6 +1588,7 @@ def reallocate_to_mempool(self, device, mempool): return + def get_global_GTP_cache() -> GTPWeightCache: """Get or lazily create the global cache instance.""" global _GTP_CACHE @@ -1573,7 +1603,9 @@ def reallocate_gtp_cache_to_mempool(device, mempool): _GTP_CACHE.reallocate_to_mempool(device, mempool) -def wait_async_comms(chain_id: str = None, skip_rs: bool = False, finalize_after_drain: bool = False): +def wait_async_comms( + chain_id: str = None, skip_rs: bool = False, finalize_after_drain: bool = False +): """Drain in-flight GTP async AG / RS handles. When called inside CUDA graph capture, the drains are captured into that @@ -1597,7 +1629,10 @@ def wait_async_comms(chain_id: str = None, skip_rs: bool = False, finalize_after * _already_finalized = True (if finalize_after_drain=True) """ for param in list(_inflight_comm_params): - if chain_id is not None and getattr(param, 'chain_id', GTPChain.UNGRAPHED.value) != chain_id: + if ( + chain_id is not None + and getattr(param, "chain_id", GTPChain.UNGRAPHED.value) != chain_id + ): continue had_ag = param._prefetch_handle is not None param._wait_param_gather() @@ -1605,7 +1640,7 @@ def wait_async_comms(chain_id: str = None, skip_rs: bool = False, finalize_after param._already_ag_drained = True if not skip_rs: param._wait_reduce_scatter(finalize_grad=finalize_after_drain) - if finalize_after_drain and not getattr(param, '_already_finalized', False): + if finalize_after_drain and not getattr(param, "_already_finalized", False): cache = get_global_GTP_cache() param.rs_event.wait() for w in param._weights: @@ -1617,6 +1652,7 @@ def wait_async_comms(chain_id: str = None, skip_rs: bool = False, finalize_after @dataclass class BatchedNVFP4AllGatherAsyncHandle: """Handle for batched asynchronous NVFP4 all-gathers.""" + output_handles: List[_NVFP4AllGatherAsyncHandle] outer_async_handle: torch.distributed.Work _synchronized: bool = False @@ -1655,7 +1691,8 @@ def grouped_gather_along_first_dim( inp = weights[0] if isinstance(inp, NVFP4TensorStorage): device = ( - inp._rowwise_data.device if inp._rowwise_data is not None + inp._rowwise_data.device + if inp._rowwise_data is not None else inp._columnwise_data.device ) else: @@ -1664,11 +1701,14 @@ def grouped_gather_along_first_dim( weights_all = [] weight_handles = [] with torch.distributed._coalescing_manager( - group=process_group, device=device, async_ops=async_op, + group=process_group, + device=device, + async_ops=async_op, ) as gather_coalescing_manager: for i, weight in enumerate(weights): weight_all, weight_handle = gather_along_first_dim( - weight, process_group, + weight, + process_group, quantizer=quantizers[i], output_tensor=output_tensors[i] if output_tensors is not None else None, grouped=True, @@ -1678,9 +1718,7 @@ def grouped_gather_along_first_dim( if async_op: handle = gather_coalescing_manager - has_nvfp4_handles = any( - isinstance(wh, _NVFP4AllGatherAsyncHandle) for wh in weight_handles - ) + has_nvfp4_handles = any(isinstance(wh, _NVFP4AllGatherAsyncHandle) for wh in weight_handles) if has_nvfp4_handles: handle = BatchedNVFP4AllGatherAsyncHandle(weight_handles, handle) else: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index fa343fc61c..7468b7ddc9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -322,7 +322,9 @@ def forward( # MCore FSDP creates main_grad lazily before backward ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] elif gtp_size > 1: - ctx.main_grad_funcs = [weights_gtp_sharded[i].get_wgrad_tensor for i in range(num_gemms)] + ctx.main_grad_funcs = [ + weights_gtp_sharded[i].get_wgrad_tensor for i in range(num_gemms) + ] else: ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_gemms) @@ -608,7 +610,8 @@ def backward( use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if ctx.gtp_size == 1 and not getattr(ctx, "origin_weights_overwrite_main_grad", False) + if ctx.gtp_size == 1 + and not getattr(ctx, "origin_weights_overwrite_main_grad", False) else False ), ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ad88c91af6..11a07daac0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -776,9 +776,7 @@ def backward( # follows. The prev_w AG prefetch issued by # all_gather_and_prefetch_bwd above then overlaps with the # wgrad GEMM. - swap_wgrad_dgrad = ( - ctx.gtp_size > 1 and GTP_CONFIG.wgrad_before_dgrad - ) + swap_wgrad_dgrad = ctx.gtp_size > 1 and GTP_CONFIG.wgrad_before_dgrad dgrad = None dgrad_work = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0b6d0535d4..5ea848d6cf 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -968,16 +968,13 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. # When GTPConfig.wgrad_before_dgrad is True and GTP is active, run # _do_wgrad before _do_dgrad so the GTP wgrad RS NCCL overlaps with # the dgrad GEMM that follows. - swap_wgrad_dgrad = ( - bwd_args.gtp_size > 1 and GTP_CONFIG.wgrad_before_dgrad - ) + swap_wgrad_dgrad = bwd_args.gtp_size > 1 and GTP_CONFIG.wgrad_before_dgrad def _do_dgrad(): nonlocal dgrad, dgrad_work, weight_fp8 if not bwd_args.requires_dgrad: return - # FSDP2: Re-create workspace from all-gathered weight when # workspace was not saved. (Issue #2681) # Use saved_weight (the original weight parameter) since @@ -1096,7 +1093,6 @@ def _do_wgrad(): if not bwd_args.requires_wgrad: return - # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available @@ -1279,7 +1275,6 @@ def wgrad_gemm( else: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone() - # Dispatch wgrad/dgrad in chosen order. if swap_wgrad_dgrad: _do_wgrad() From ea1905568db4f027a60a03b4f54d18606581bf9f Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Mon, 18 May 2026 00:30:01 -0700 Subject: [PATCH 3/8] fix lint; fix comments Signed-off-by: Shiqing Fan --- tests/pytorch/distributed/test_gtp.py | 119 ++++++++++++++++- tests/pytorch/distributed/test_tp_gtp.py | 5 +- transformer_engine/pytorch/distributed.py | 19 +-- .../module/generalized_tensor_parallelism.py | 124 ++++++++++++------ .../pytorch/module/grouped_linear.py | 7 +- .../pytorch/module/layernorm_linear.py | 2 +- 6 files changed, 218 insertions(+), 58 deletions(-) diff --git a/tests/pytorch/distributed/test_gtp.py b/tests/pytorch/distributed/test_gtp.py index f29075e06f..84e3968dc0 100644 --- a/tests/pytorch/distributed/test_gtp.py +++ b/tests/pytorch/distributed/test_gtp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -28,6 +28,7 @@ 20. TestGTPPrefetchDisabled – weight_prefetch=False: single-pass forward still works (multi-GPU) 21. TestFuseWgradAccumulation – fuse_wgrad_accumulation=True: wgrad→main_grad (multi-GPU) 22. TestGTPGradAccumHook – main_grad updated after reduce-scatter backward (multi-GPU) +23. TestWaitAsyncCommsFallback – wait_async_comms(finalize_after_drain=True) inline-accumulation fallback when _wgrad_rs_handle is None (single-process) Multi-GPU tests use torch.multiprocessing.spawn and are skipped when fewer than the required CUDA devices are available. @@ -71,8 +72,6 @@ def reset_fp8_state(): def reset_gtp_globals(): """Reset all GTP mutable class/module-level state between tests.""" yield - GTPShardedParam._first_weight_flag = True - GTPShardedParam._pending_rs_weight = None GTPShardedParam._chain_state = {} @@ -1486,3 +1485,117 @@ class TestGTPGradAccumHook: def test_main_grad_updated_after_backward(self): _requires_multi_gpu(4) _run_distributed(_worker_main_grad_updated_after_bwd, 4) + + +# --------------------------------------------------------------------------- +# 24. wait_async_comms(finalize_after_drain=True) inline-accumulation fallback +# --------------------------------------------------------------------------- + + +class TestWaitAsyncCommsFallback: + """Exercises the inline-accumulation fallback inside + ``wait_async_comms(finalize_after_drain=True)``: when a param is in + ``_inflight_comm_params`` (async AG was issued) but its ``_wgrad_rs_handle`` + is ``None`` (no async RS handle to drain), the inner + ``_wait_reduce_scatter`` call no-ops and the outer loop must inline the + accumulation itself (main_grad.add_ + ticket release + flag set). + + Production flows rarely hit this combination — chain-interior params have + both async AG and async RS, and chain-head sync RS doesn't enter + ``_inflight_comm_params`` via bwd AG. We construct the state by hand to + pin down the fallback's contract. + """ + + class _FakeGroup: + def size(self): return 1 + def rank(self): return 0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_fallback_accumulates_when_no_rs_handle(self): + dtype = torch.bfloat16 + p = GTPShardedParam(torch.zeros(8, 4, dtype=dtype, device="cuda")) + p.group = self._FakeGroup() + p.expert_idx = None + p.pad_length = 0 + p.chain_id = gtp_module.GTPChain.UNGRAPHED.value + p._quantizer = None + p.is_routed_expert = False # ⇒ self._weights property returns [self] + p.main_grad = torch.zeros(8, 4, dtype=dtype, device="cuda") + p._prefetch_handle = None # _wait_param_gather is no-op + p._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires + p._cached_ag_stream = None + p._cached_rs_stream = None + p.ag_event = torch.cuda.Event(external=True) + p.rs_event = torch.cuda.Event(external=True) + p.rs_event.record() # so rs_event.wait() in fallback doesn't block + p._already_finalized = False + p.grad_added_to_main_grad = False + + # Place a known wgrad in the cache for the fallback to read. + cache = gtp_module.get_global_GTP_cache() + p._rs_ticket = cache.reserve(p, dtype, fwd=False, reduce_scatter=True) + cache.get(p._rs_ticket).fill_(2.0) + + # Save + replace _inflight_comm_params so we don't trip over leftover + # params from earlier tests in the loop. + saved = set(gtp_module._inflight_comm_params) + gtp_module._inflight_comm_params.clear() + gtp_module._inflight_comm_params.add(p) + try: + gtp_module.wait_async_comms( + chain_id=p.chain_id, + skip_rs=False, + finalize_after_drain=True, + ) + finally: + gtp_module._inflight_comm_params.clear() + gtp_module._inflight_comm_params.update(saved) + + torch.cuda.synchronize() + assert torch.all(p.main_grad == 2.0), \ + f"main_grad should be 2.0 after fallback accumulation; got {p.main_grad}" + assert p._already_finalized is True, "_already_finalized must be set" + assert p.grad_added_to_main_grad is True, "grad_added_to_main_grad must be set" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_fallback_skipped_when_already_finalized(self): + """When _already_finalized=True, the fallback must NOT re-accumulate.""" + dtype = torch.bfloat16 + p = GTPShardedParam(torch.zeros(8, 4, dtype=dtype, device="cuda")) + p.group = self._FakeGroup() + p.expert_idx = None + p.pad_length = 0 + p.chain_id = gtp_module.GTPChain.UNGRAPHED.value + p._quantizer = None + p.is_routed_expert = False # ⇒ self._weights property returns [self] + # Pre-existing main_grad with a value the fallback must NOT overwrite. + p.main_grad = torch.full((8, 4), 5.0, dtype=dtype, device="cuda") + p._prefetch_handle = None + p._wgrad_rs_handle = None + p._cached_ag_stream = None + p._cached_rs_stream = None + p.ag_event = torch.cuda.Event(external=True) + p.rs_event = torch.cuda.Event(external=True) + p.rs_event.record() + p._already_finalized = True # ← short-circuits the fallback + + # No _rs_ticket: if the fallback ran it would AttributeError on + # cache.get(None). The skip path must not touch the cache at all. + p._rs_ticket = None + + saved = set(gtp_module._inflight_comm_params) + gtp_module._inflight_comm_params.clear() + gtp_module._inflight_comm_params.add(p) + try: + gtp_module.wait_async_comms( + chain_id=p.chain_id, + skip_rs=False, + finalize_after_drain=True, + ) + finally: + gtp_module._inflight_comm_params.clear() + gtp_module._inflight_comm_params.update(saved) + + torch.cuda.synchronize() + assert torch.all(p.main_grad == 5.0), \ + "main_grad must be untouched when _already_finalized=True" \ No newline at end of file diff --git a/tests/pytorch/distributed/test_tp_gtp.py b/tests/pytorch/distributed/test_tp_gtp.py index ce739e43d9..44381e5bff 100644 --- a/tests/pytorch/distributed/test_tp_gtp.py +++ b/tests/pytorch/distributed/test_tp_gtp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -17,7 +17,6 @@ 2. TestTPGTPColumnParallelLinear – column-parallel Linear: weight shape + fwd/bwd correctness 3. TestTPGTPRowParallelLinear – row-parallel Linear: weight shape + fwd/bwd smoke test 4. TestTPGTPLayerNormLinear – LayerNormLinear column-parallel smoke test -5. TestTPGTPLayerNormMLP – LayerNormMLP (column FC1 + row FC2) smoke test Tests use (tp_size, gtp_size) = (2, 2) → world_size = 4 (runs on 4-GPU machines). @@ -53,8 +52,6 @@ def reset_fp8_state(): def reset_gtp_globals(): """Reset GTP mutable class/module-level state between tests.""" yield - GTPShardedParam._first_weight_flag = True - GTPShardedParam._pending_rs_weight = None GTPShardedParam._chain_state = {} diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 3e4d005883..282b53022b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1284,15 +1284,13 @@ def _post_process_nvfp4_gather( handle.wait() handle = None - # TODO - # # Fix the interleaved transposed data from gathering along first dim. - # out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) - # out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + # Fix the interleaved transposed data from gathering along first dim. + # In-place .copy_() (not `=` rebind) to keep the storage address stable + # for CUDA graph capture — replays see the same pointer they captured. out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) - # # Optionally pad the scaling inverse if needed. - # out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + # Optionally pad the scaling inverse if needed (same in-place pattern). out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) @@ -1308,6 +1306,10 @@ class _NVFP4AllGatherAsyncHandle: _synchronized: bool = False def post_process_nvfp4_gather(self) -> None: + """Fix interleaved transposed data + pad scale_inv after the async AG completes. + + Idempotent: gated by ``_synchronized`` in :meth:`wait`. + """ _post_process_nvfp4_gather( self.output, self.columnwise_data_interleaved, @@ -1454,9 +1456,8 @@ def _all_gather_nvfp4( group=process_group, ) - # Transfer amax to output. - # TODO: jiemingz - # out._amax_rowwise = inp._amax_rowwise + # Transfer amax to output via in-place .copy_() so the storage + # address stays stable for CUDA graph capture. out._amax_rowwise.copy_(inp._amax_rowwise) # Gather the transposed NVFP4 data along first dimension. Fix format later. diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py index d6028a72bb..c3ab2414af 100644 --- a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -1,16 +1,28 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +"""Generalized Tensor Parallelism (GTP). + +Shards weight tensors 1/N across a GTP process group along ``out_features`` +and materializes them on-demand via async all-gather, with a per-weight +prefetch chain + ticket-based buffer cache co-designed for CUDA graph +capture/replay. Quantized AG (FP8 / MXFP8 / NVFP4) composes with the +sharding for compounding bandwidth reduction. +""" + from collections import defaultdict from contextlib import nullcontext -from typing import Dict, List, Optional -from enum import Enum from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional import math import re + import torch +import transformer_engine_torch as tex +from ..constants import NVFP4_BLOCK_SCALING_SIZE, MXFP8_BLOCK_SCALING_SIZE from ..distributed import ( gather_along_first_dim, reduce_scatter_along_first_dim, @@ -20,11 +32,8 @@ from ..tensor import NVFP4TensorStorage, MXFP8TensorStorage from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..utils import nvtx_range_pop, nvtx_range_push, round_up_to_nearest_multiple -from ..constants import NVFP4_BLOCK_SCALING_SIZE, MXFP8_BLOCK_SCALING_SIZE from .base import get_dummy_wgrad -import transformer_engine_torch as tex - class GTPChain(str, Enum): """Prefetch chain identifier for an GTPShardedParam. @@ -133,6 +142,7 @@ def classify_gtp_chains(model) -> None: class GTPWeightState(Enum): + """State of a GTPShardedParam's AG / RS lifecycle (debug / stale-read guard).""" NONE = "NONE" # Sharded, no pending operation ASYNC_WAIT = "ASYNC_WAIT" # Async all-gather in progress DATA_READY = "DATA_READY" # Async all-gather complete, result in cache @@ -256,9 +266,13 @@ def _coalesced_amax_static_eligible(weights): return True -def _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag): +def _quantize_with_coalesced_amax(weights, cast_noop_flag): """Replace the per-weight (compute_amax + allreduce + cast) loop with: - compute_amax loop → one coalesced allreduce → cast loop.""" + compute_amax loop → one coalesced allreduce → cast loop. + + The caller has already gated on ``skip_weight_cast`` (see + ``_all_gather_weight``); inside this function we always do the work. + """ group = weights[0]._quantizer.amax_reduction_group # Materialize padded shards once; on padded last-rank get_padded_shard() @@ -498,6 +512,12 @@ def gtp_finalize_module_in_reset_parameters(module, weight_names): class GTPShardHandle: + """Wrapper around a ``dist`` async-work handle for a GTP AG / RS. + + Tracks the participating shards so the wait-site can transition their + ``GTPWeightState`` and so the GTP module can prune the param from + ``_inflight_comm_params`` when the collective completes. + """ def __init__(self, handle, gtp_shards, reduce_scatter=False): self.handle = handle @@ -506,6 +526,7 @@ def __init__(self, handle, gtp_shards, reduce_scatter=False): _inflight_comm_params.add(gtp_shards[0]) def wait(self): + """Wait on the underlying NCCL work and update the shards' state.""" if self.handle is not None: self.handle.wait() self.handle = None # Release NCCL Work and its C++ tensor references promptly @@ -520,9 +541,14 @@ def wait(self): class GTPShardedParam(torch.nn.Parameter): + """A weight parameter sharded 1/N across a GTP process group. + + Materialized on-demand via async all-gather and gradient-reduced via + reduce-scatter. Carries its own prefetch-chain wiring (``prev_w`` / + ``next_w``), per-chain state, AG/RS cache tickets, and the metadata the + integrator needs to drive overlap with captured compute. + """ - _pending_rs_weight = None - _first_weight_flag = True # Per-chain state: each chain_id (GTPChain.GRAPHED / GTPChain.UNGRAPHED) has # its own linked list. Chains never cross-link: prev_w/next_w only connect # params with the same chain_id. @@ -568,11 +594,13 @@ def _layer_id(name: str) -> str: ) @staticmethod - def __new__(cls, tensor, *args, **kwargs): + def __new__(cls, tensor, *args, **kwargs): # pylint: disable=unused-argument requires_grad = kwargs.get("requires_grad", True) + # pylint: disable-next=unexpected-keyword-arg return super(GTPShardedParam, cls).__new__(cls, tensor, requires_grad=requires_grad) - def __init__(self, x, *args, **kwargs): + def __init__(self, tensor, *args, **kwargs): + del tensor, args, kwargs super().__init__() # all gather @@ -587,7 +615,8 @@ def __init__(self, x, *args, **kwargs): # classify_gtp_chains() sets this to False for embedding.word_embeddings.weight. self._need_weight_prefetch_bwd = True self.ag_event = torch.cuda.Event(external=True) - # DDP backward hook (set by register_grad_accum_hook); invoked from _finalize_wgrad. + # DDP backward hook (set by register_grad_accum_hook); invoked after + # the wgrad RS accumulation completes (Graphed.backward / chain cascade). self._grad_accum_hook = None # Quantization self._quantizer = None @@ -672,30 +701,36 @@ def _weights(self): @property def _unsharded_shape_padded(self): + """Full unsharded shape *including* the pad rows on the last rank.""" out_shape = list(self.size()) out_shape[0] = out_shape[0] * self.group.size() return tuple(out_shape) @property def _unsharded_shape(self): + """Full unsharded shape with the pad rows stripped (logical shape).""" out_shape = list(self._unsharded_shape_padded) out_shape[0] -= self.pad_length return tuple(out_shape) @property def _sharded_padded_shape(self): + """This rank's local shard shape, padding included.""" return tuple(self.size()) def get_padded_shard(self): + """Return the local shard already containing its share of padding (identity).""" return self def _set_state(self, new_state: GTPWeightState): + """Advance the AG state (only inspected when ``check_param_states`` is on).""" # Only inspected when check_param_states is on; skip writes otherwise. if not GTP_CONFIG.check_param_states: return self.state = new_state def _set_rs_state(self, new_state: GTPWeightState): + """Advance the RS state (only inspected when ``check_param_states`` is on).""" if not GTP_CONFIG.check_param_states: return self.rs_state = new_state @@ -792,8 +827,8 @@ def _strip_padding(self, tensor): metadata["columnwise_scale_inv"] = metadata["columnwise_scale_inv"][:m_tiles] return type(tensor)(**metadata, shape=self._unsharded_shape, dtype=torch.bfloat16) - else: - return tensor[: -self.pad_length] + + return tensor[: -self.pad_length] def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nvtx_label=None): """Quantize (if needed) and all-gather weight. Returns (weight_total, handle).""" @@ -832,7 +867,7 @@ def _all_gather_weight(self, async_op, skip_weight_cast, cast_noop_flag, fwd, nv fp8_pg_hit = GTP_CONFIG.fp8_param_gather and self.did_cast_to_low_precision if use_coalesced: - _quantize_with_coalesced_amax(weights, skip_weight_cast, cast_noop_flag) + _quantize_with_coalesced_amax(weights, cast_noop_flag) elif not fp8_pg_hit: for w in weights: w._quantize_if_needed(skip_weight_cast, cast_noop_flag) @@ -965,6 +1000,9 @@ def _all_gather_weight_on_demand(self, fwd, skip_weight_cast=False, cast_noop_fl return result if self.is_routed_expert else result[0] def _get_prefetched_weight(self, fwd, skip_weight_cast=False, cast_noop_flag=None): + # ``skip_weight_cast`` and ``cast_noop_flag`` are accepted to keep the + # signature symmetric with ``_all_gather_weight_on_demand``. + del skip_weight_cast, cast_noop_flag # Stale-read guard: state must reflect an AG issued for this cycle; # otherwise cache.get() would return the prior iter's AG buffer. if GTP_CONFIG.check_param_states: @@ -1134,13 +1172,15 @@ def batched_all_gather_and_prefetch(self, **kwargs): return self.all_gather_and_prefetch(**kwargs) def get_wgrad_tensor(self): + """Pool-allocate a wgrad scratch tensor of unsharded shape for the bwd GEMM.""" return _wgrad_pool_get(self._unsharded_shape, self.main_grad.dtype, self.device) def register_grad_accum_hook(self, grad_accum_node, hook): - """Register a DDP backward hook to be called from _finalize_wgrad. + """Register a DDP backward hook to be called after the wgrad RS finalize. For GTP params, autograd may receive None (async RS) so the normal grad - accumulator hook never fires. Instead, _finalize_wgrad calls the hook + accumulator hook never fires. Instead, the integrator (Graphed.backward + for captured chains, or the eager chain-tail cascade) calls this hook explicitly after RS wait + gradient accumulation, ensuring DDP's register_grad_ready fires at exactly the right time. @@ -1255,22 +1295,22 @@ def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): ) nvtx_range_pop(f"{nvtx_label}.gtp_rs") return [out], handle - else: - outputs = [] - nvtx_range_push(f"{nvtx_label}.batched_gtp_rs") - with torch.distributed._coalescing_manager( - group=self.group, - device=wgrads[0].device, - async_ops=async_op, - ) as cm: - for out_buffer, tensor in zip(out_buffers, wgrads): - out, _ = reduce_scatter_along_first_dim( - tensor, self.group, output=out_buffer - ) - outputs.append(out) - nvtx_range_pop(f"{nvtx_label}.batched_gtp_rs") - - return outputs, cm if async_op else None + + outputs = [] + nvtx_range_push(f"{nvtx_label}.batched_gtp_rs") + with torch.distributed._coalescing_manager( + group=self.group, + device=wgrads[0].device, + async_ops=async_op, + ) as cm: + for out_buffer, tensor in zip(out_buffers, wgrads): + out, _ = reduce_scatter_along_first_dim( + tensor, self.group, output=out_buffer + ) + outputs.append(out) + nvtx_range_pop(f"{nvtx_label}.batched_gtp_rs") + + return outputs, cm if async_op else None def wgrad_reduce_scatter(self, wgrad, nvtx_label=None): """Reduce-scatter wgrad(s). Sync for last weight, async+deferred for others. @@ -1336,6 +1376,8 @@ def batched_wgrad_reduce_scatter(self, wgrad_list, nvtx_label=None): return self.wgrad_reduce_scatter(wgrad_list, nvtx_label=nvtx_label) def __torch_function__(self, func, types, args=(), kwargs=None): + """Subclass-preserving dispatch for ``detach`` (other ops fall through).""" + del types # required by protocol, unused here if kwargs is None: kwargs = {} @@ -1586,8 +1628,6 @@ def reallocate_to_mempool(self, device, mempool): assert False torch._C._cuda_endAllocateToPool(device, mempool) - return - def get_global_GTP_cache() -> GTPWeightCache: """Get or lazily create the global cache instance.""" @@ -1622,7 +1662,7 @@ def wait_async_comms( main_grad. Runs main_grad.add_ on rs_stream (right after NCCL RS) so it starts during AG drain rather than after, avoiding SM-saturation that blocks cross-graph overlap. - Falls back to caller-stream _finalize_wgrad if no RS handle. + Falls back to caller-stream accumulation if no RS handle. Per-param side effects: * _already_ag_drained = True (if an AG handle was drained) @@ -1641,11 +1681,19 @@ def wait_async_comms( if not skip_rs: param._wait_reduce_scatter(finalize_grad=finalize_after_drain) if finalize_after_drain and not getattr(param, "_already_finalized", False): + # Fallback path: _wait_reduce_scatter ran sync (no handle to + # drain), so it didn't accumulate. Do the inline accumulation + # here, matching the wgrad_reduce_scatter / _wait_reduce_scatter + # finalize-grad pattern (state reset + main_grad.add_ + release). cache = get_global_GTP_cache() param.rs_event.wait() for w in param._weights: - GTPShardedParam._finalize_wgrad(w, cache.get(w._rs_ticket)) + w._set_rs_state(GTPWeightState.NONE) + wgrad_rs = cache.get(w._rs_ticket) + w.main_grad.add_(wgrad_rs) cache.release(w._rs_ticket) + if hasattr(w, "grad_added_to_main_grad"): + w.grad_added_to_main_grad = True param._already_finalized = True diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 7468b7ddc9..b309c685f0 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -5,7 +5,6 @@ """GroupedLinear API""" from typing import Union, Optional, Callable, Tuple, List from itertools import chain -import traceback import warnings import weakref @@ -117,8 +116,8 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad + weights_gtp_sharded = weights if gtp_size > 1: - weights_gtp_sharded = weights weights = weights[0].batched_all_gather_and_prefetch( fwd=True, skip_weight_cast=is_first_microbatch is False, @@ -390,6 +389,7 @@ def backward( inputmats = saved_tensors[:N] gtp_origin_weights = saved_tensors[N : 2 * N] biases = saved_tensors[2 * N : 3 * N] + weights = None # Restore from weakrefs to get original weight python objects # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) @@ -538,6 +538,7 @@ def backward( # Gathered weights are no longer needed after dgrad GEMM. # For nvfp4, the NVFP4TensorStorage and its sub-tensors (scale_inv etc.) # would otherwise survive until function return via this local ref. + w_shape = None if ctx.gtp_size > 1: w_shape = list(weights[0].size()) del weights @@ -843,7 +844,7 @@ def __init__( self.gtp_size = 1 else: self.gtp_size = get_distributed_world_size(gtp_group) - assert tp_size == 1, f"TODO(shiqingf): GTP+TP is not well supported yet." + assert tp_size == 1, "GroupedLinear with GTP does not support TP > 1 yet." self.parallel_mode = parallel_mode if self.parallel_mode not in GemmParallelModes: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 11a07daac0..565e0423d9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -300,8 +300,8 @@ def forward( # Prepare weight tensor # ------------------------------------------------------ + weight_gtp_sharded = weight if gtp_size > 1: - weight_gtp_sharded = weight weight = weight.all_gather_and_prefetch( fwd=True, skip_weight_cast=is_first_microbatch is False, From 5419298bae5ace15e0cf4aacbe94c8465fb1afee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 07:32:25 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/test_gtp.py | 29 +++++++++++-------- .../module/generalized_tensor_parallelism.py | 5 ++-- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/distributed/test_gtp.py b/tests/pytorch/distributed/test_gtp.py index 84e3968dc0..ace6af0448 100644 --- a/tests/pytorch/distributed/test_gtp.py +++ b/tests/pytorch/distributed/test_gtp.py @@ -1507,8 +1507,11 @@ class TestWaitAsyncCommsFallback: """ class _FakeGroup: - def size(self): return 1 - def rank(self): return 0 + def size(self): + return 1 + + def rank(self): + return 0 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") def test_fallback_accumulates_when_no_rs_handle(self): @@ -1519,15 +1522,15 @@ def test_fallback_accumulates_when_no_rs_handle(self): p.pad_length = 0 p.chain_id = gtp_module.GTPChain.UNGRAPHED.value p._quantizer = None - p.is_routed_expert = False # ⇒ self._weights property returns [self] + p.is_routed_expert = False # ⇒ self._weights property returns [self] p.main_grad = torch.zeros(8, 4, dtype=dtype, device="cuda") - p._prefetch_handle = None # _wait_param_gather is no-op - p._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires + p._prefetch_handle = None # _wait_param_gather is no-op + p._wgrad_rs_handle = None # _wait_reduce_scatter is no-op → fallback fires p._cached_ag_stream = None p._cached_rs_stream = None p.ag_event = torch.cuda.Event(external=True) p.rs_event = torch.cuda.Event(external=True) - p.rs_event.record() # so rs_event.wait() in fallback doesn't block + p.rs_event.record() # so rs_event.wait() in fallback doesn't block p._already_finalized = False p.grad_added_to_main_grad = False @@ -1552,8 +1555,9 @@ def test_fallback_accumulates_when_no_rs_handle(self): gtp_module._inflight_comm_params.update(saved) torch.cuda.synchronize() - assert torch.all(p.main_grad == 2.0), \ - f"main_grad should be 2.0 after fallback accumulation; got {p.main_grad}" + assert torch.all( + p.main_grad == 2.0 + ), f"main_grad should be 2.0 after fallback accumulation; got {p.main_grad}" assert p._already_finalized is True, "_already_finalized must be set" assert p.grad_added_to_main_grad is True, "grad_added_to_main_grad must be set" @@ -1567,7 +1571,7 @@ def test_fallback_skipped_when_already_finalized(self): p.pad_length = 0 p.chain_id = gtp_module.GTPChain.UNGRAPHED.value p._quantizer = None - p.is_routed_expert = False # ⇒ self._weights property returns [self] + p.is_routed_expert = False # ⇒ self._weights property returns [self] # Pre-existing main_grad with a value the fallback must NOT overwrite. p.main_grad = torch.full((8, 4), 5.0, dtype=dtype, device="cuda") p._prefetch_handle = None @@ -1577,7 +1581,7 @@ def test_fallback_skipped_when_already_finalized(self): p.ag_event = torch.cuda.Event(external=True) p.rs_event = torch.cuda.Event(external=True) p.rs_event.record() - p._already_finalized = True # ← short-circuits the fallback + p._already_finalized = True # ← short-circuits the fallback # No _rs_ticket: if the fallback ran it would AttributeError on # cache.get(None). The skip path must not touch the cache at all. @@ -1597,5 +1601,6 @@ def test_fallback_skipped_when_already_finalized(self): gtp_module._inflight_comm_params.update(saved) torch.cuda.synchronize() - assert torch.all(p.main_grad == 5.0), \ - "main_grad must be untouched when _already_finalized=True" \ No newline at end of file + assert torch.all( + p.main_grad == 5.0 + ), "main_grad must be untouched when _already_finalized=True" diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py index c3ab2414af..5964a41a7b 100644 --- a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -143,6 +143,7 @@ def classify_gtp_chains(model) -> None: class GTPWeightState(Enum): """State of a GTPShardedParam's AG / RS lifecycle (debug / stale-read guard).""" + NONE = "NONE" # Sharded, no pending operation ASYNC_WAIT = "ASYNC_WAIT" # Async all-gather in progress DATA_READY = "DATA_READY" # Async all-gather complete, result in cache @@ -1304,9 +1305,7 @@ def _reduce_scatter(self, wgrads, async_op, nvtx_label=None): async_ops=async_op, ) as cm: for out_buffer, tensor in zip(out_buffers, wgrads): - out, _ = reduce_scatter_along_first_dim( - tensor, self.group, output=out_buffer - ) + out, _ = reduce_scatter_along_first_dim(tensor, self.group, output=out_buffer) outputs.append(out) nvtx_range_pop(f"{nvtx_label}.batched_gtp_rs") From d5c1b2451f1527f281599a54423279949228d254 Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Mon, 18 May 2026 01:19:39 -0700 Subject: [PATCH 5/8] fix lint and comments. Signed-off-by: Shiqing Fan --- tests/pytorch/distributed/test_gtp.py | 52 +++++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 15 ++++-- .../module/generalized_tensor_parallelism.py | 15 ++++-- transformer_engine/pytorch/module/linear.py | 4 +- 4 files changed, 75 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/distributed/test_gtp.py b/tests/pytorch/distributed/test_gtp.py index ace6af0448..fe25809029 100644 --- a/tests/pytorch/distributed/test_gtp.py +++ b/tests/pytorch/distributed/test_gtp.py @@ -1604,3 +1604,55 @@ def test_fallback_skipped_when_already_finalized(self): assert torch.all( p.main_grad == 5.0 ), "main_grad must be untouched when _already_finalized=True" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_fallback_skipped_for_pure_ag_param(self): + """Regression: cross-graph fwd-AG prefetch in flight + finalize_after_drain=True. + + A param can be in _inflight_comm_params because of an outstanding async + all-gather (e.g. a cross-graph forward prefetch reaching the + bwd→optimizer boundary). No reduce-scatter was ever issued for that + param, so _rs_ticket is None on every weight. Previously the fallback + called cache.get(None) and crashed with KeyError; the guard now skips + the inline accumulation entirely when no weight has an RS ticket. + """ + dtype = torch.bfloat16 + p = GTPShardedParam(torch.zeros(8, 4, dtype=dtype, device="cuda")) + p.group = self._FakeGroup() + p.expert_idx = None + p.pad_length = 0 + p.chain_id = gtp_module.GTPChain.UNGRAPHED.value + p._quantizer = None + p.is_routed_expert = False + # Pre-existing main_grad with a sentinel that must survive untouched. + p.main_grad = torch.full((8, 4), 7.0, dtype=dtype, device="cuda") + p._prefetch_handle = None + p._wgrad_rs_handle = None + p._cached_ag_stream = None + p._cached_rs_stream = None + p.ag_event = torch.cuda.Event(external=True) + p.rs_event = torch.cuda.Event(external=True) + p.rs_event.record() + p._already_finalized = False + # Critical: simulates a pure-AG prefetch — no RS ever issued, ticket is None. + p._rs_ticket = None + + saved = set(gtp_module._inflight_comm_params) + gtp_module._inflight_comm_params.clear() + gtp_module._inflight_comm_params.add(p) + try: + # Must NOT raise KeyError(None) from cache.get(None). + gtp_module.wait_async_comms( + chain_id=p.chain_id, + skip_rs=False, + finalize_after_drain=True, + ) + finally: + gtp_module._inflight_comm_params.clear() + gtp_module._inflight_comm_params.update(saved) + + torch.cuda.synchronize() + assert torch.all(p.main_grad == 7.0), \ + "main_grad must be untouched for a pure-AG param (no wgrad to accumulate)" + assert p._already_finalized is False, \ + "_already_finalized must stay False — no finalize happened for a pure-AG param" diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 696cc51811..7a0722fb51 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -133,10 +133,10 @@ void init_extension() { #include "common/util/pybind_helper.h" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) - m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), - py::arg("output") = py::none(), py::arg("noop") = py::none()); +// Bindings for the NVFP4 split-amax fast-path used by GTP (coalesced +// per-expert amax allreduce + cast). Kept separate from PYBIND11_MODULE +// so the latter stays under cpplint's 500-line readability/fn_size limit. +static void RegisterNvfp4AmaxBindings(py::module &m) { m.def("compute_amax_nvfp4", transformer_engine::pytorch::compute_amax_nvfp4, "NVFP4: compute local amax into output's amax buffers; no cast, no allreduce", py::arg("tensor"), py::arg("quantizer"), py::arg("output") = py::none()); @@ -148,6 +148,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_multi_amax_nvfp4", transformer_engine::pytorch::compute_multi_amax_nvfp4, "NVFP4: fused multi-tensor amax compute (writes both rowwise+columnwise amax per output)", py::arg("tensor_list"), py::arg("quantizer_list"), py::arg("output_list")); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), + py::arg("output") = py::none(), py::arg("noop") = py::none()); + RegisterNvfp4AmaxBindings(m); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); m.def("create_empty_quantized_tensor", diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py index 5964a41a7b..8fd5e214a1 100644 --- a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -1679,11 +1679,16 @@ def wait_async_comms( param._already_ag_drained = True if not skip_rs: param._wait_reduce_scatter(finalize_grad=finalize_after_drain) - if finalize_after_drain and not getattr(param, "_already_finalized", False): - # Fallback path: _wait_reduce_scatter ran sync (no handle to - # drain), so it didn't accumulate. Do the inline accumulation - # here, matching the wgrad_reduce_scatter / _wait_reduce_scatter - # finalize-grad pattern (state reset + main_grad.add_ + release). + # Fallback inline-accumulation: only when finalize is requested, + # _wait_reduce_scatter didn't already finalize, and an RS actually + # ran for this param (rs_ticket set). Skips pure-AG prefetches in + # _inflight_comm_params (no wgrad to accumulate). + need_fallback_accumulation = ( + finalize_after_drain + and not getattr(param, "_already_finalized", False) + and any(w._rs_ticket is not None for w in param._weights) + ) + if need_fallback_accumulation: cache = get_global_GTP_cache() param.rs_event.wait() for w in param._weights: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5ea848d6cf..0ebd4adcae 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -154,7 +154,7 @@ class LinearFwdArgs: cpu_offloading: bool is_grad_enabled: bool - # --- Extended tensor parallelism --- + # --- Generalized tensor parallelism --- gtp_size: int = 1 @@ -226,7 +226,7 @@ class LinearBwdArgs: cpu_offloading: bool = False owns_input: bool = False - # --- Extended tensor parallelism --- + # --- Generalized tensor parallelism --- gtp_size: int = 1 # --- Per-backward scratch state (populated inside _linear_backward) --- From c401c68aadf2f40cba13b3506d88faf9928bdd8d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 08:22:25 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/test_gtp.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/distributed/test_gtp.py b/tests/pytorch/distributed/test_gtp.py index fe25809029..06a9527434 100644 --- a/tests/pytorch/distributed/test_gtp.py +++ b/tests/pytorch/distributed/test_gtp.py @@ -1652,7 +1652,9 @@ def test_fallback_skipped_for_pure_ag_param(self): gtp_module._inflight_comm_params.update(saved) torch.cuda.synchronize() - assert torch.all(p.main_grad == 7.0), \ - "main_grad must be untouched for a pure-AG param (no wgrad to accumulate)" - assert p._already_finalized is False, \ - "_already_finalized must stay False — no finalize happened for a pure-AG param" + assert torch.all( + p.main_grad == 7.0 + ), "main_grad must be untouched for a pure-AG param (no wgrad to accumulate)" + assert ( + p._already_finalized is False + ), "_already_finalized must stay False — no finalize happened for a pure-AG param" From c49bd467c694618d6b0c9a8ec70891a3b932c76c Mon Sep 17 00:00:00 2001 From: Shiqing Fan Date: Mon, 18 May 2026 02:24:17 -0700 Subject: [PATCH 7/8] fix comments. Signed-off-by: Shiqing Fan --- .../module/generalized_tensor_parallelism.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py index 8fd5e214a1..860ff8a76b 100644 --- a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -1433,13 +1433,17 @@ class GTPWeightCache: lazily allocates fresh buffers. """ - # Bytes per element for known dtypes (used for logging). + # Bytes per element for known dtypes (used for logging). Add new entries + # here when GTP starts caching buffers of additional quantized dtypes. + # Only DType values guaranteed exposed by the TE pybind bindings — verify + # via ``hasattr(tex.DType, ...)`` before adding speculative entries. _BYTES_PER_ELEMENT = { torch.bfloat16: 2, torch.float16: 2, torch.float32: 4, tex.DType.kFloat4E2M1: 0.5, tex.DType.kFloat8E4M3: 1, + tex.DType.kFloat8E5M2: 1, } def __init__(self): @@ -1455,8 +1459,12 @@ def _buf_bytes(shape, dtype) -> int: numel = 1 for d in shape: numel *= d - bpe = GTPWeightCache._BYTES_PER_ELEMENT.get(dtype, None) - return numel * bpe + if dtype not in GTPWeightCache._BYTES_PER_ELEMENT: + raise KeyError( + f"GTPWeightCache._buf_bytes: unknown dtype {dtype!r}. " + f"Add it to GTPWeightCache._BYTES_PER_ELEMENT with its bytes-per-element." + ) + return int(numel * GTPWeightCache._BYTES_PER_ELEMENT[dtype]) def _allocate_buffer( self, param: "GTPShardedParam", dtype, reduce_scatter, fwd From 783ac5f0bc6d8cd81dfd57ae5e110a7ed749f0eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 09:28:21 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/module/generalized_tensor_parallelism.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py index 860ff8a76b..fbc3efd7f6 100644 --- a/transformer_engine/pytorch/module/generalized_tensor_parallelism.py +++ b/transformer_engine/pytorch/module/generalized_tensor_parallelism.py @@ -1462,7 +1462,7 @@ def _buf_bytes(shape, dtype) -> int: if dtype not in GTPWeightCache._BYTES_PER_ELEMENT: raise KeyError( f"GTPWeightCache._buf_bytes: unknown dtype {dtype!r}. " - f"Add it to GTPWeightCache._BYTES_PER_ELEMENT with its bytes-per-element." + "Add it to GTPWeightCache._BYTES_PER_ELEMENT with its bytes-per-element." ) return int(numel * GTPWeightCache._BYTES_PER_ELEMENT[dtype])