diff --git a/tests/jax/test_fused_router.py b/tests/jax/test_fused_router.py index 89a32f1ce2..ff62bce772 100644 --- a/tests/jax/test_fused_router.py +++ b/tests/jax/test_fused_router.py @@ -559,3 +559,139 @@ def loss_fused_fn(probs_): assert jnp.allclose( grad_ref, grad_fused, atol=1e-5, rtol=1e-5 ), f"Grad mismatch: max diff = {jnp.abs(grad_ref - grad_fused).max()}" + + +# ============================================================================= +# Test: routing_map BITMAP_U8 vs BYTEMAP parity (fwd + bwd) +# ============================================================================= + + +def _bytemap_to_bitmap_u8(bytemap): + """Reference packer: bool[T, E] -> uint8[T, ceil(E/8)] LSB-first.""" + import numpy as np + + flat = np.asarray(bytemap).astype(np.uint8) + return np.packbits(flat, axis=-1, bitorder="little") + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + TOPK_CASES, +) +@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +@pytest.mark.triton +def test_topk_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function): + """fused_topk_with_score_function should produce the same probs and an + LSB-packed bitmap routing_map when routing_map_format=BITMAP_U8, and + backward gradients should match the bytemap path exactly.""" + from transformer_engine.jax.router import RoutingMapFormat + + logits = make_logits(num_tokens, num_experts, score_function, dtype) + expert_bias = None + + fwd_byte = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + expert_bias=expert_bias, + routing_map_format=RoutingMapFormat.BYTEMAP, + ) + ) + fwd_bit = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + expert_bias=expert_bias, + routing_map_format=RoutingMapFormat.BITMAP_U8, + ) + ) + probs_byte, routing_map_byte = fwd_byte(logits) + probs_bit, routing_map_bit = fwd_bit(logits) + + assert 
probs_byte.dtype == probs_bit.dtype + assert jnp.array_equal(probs_byte, probs_bit), "Probs must be identical across formats" + + packed_expected = _bytemap_to_bitmap_u8(routing_map_byte) + assert routing_map_bit.shape == ( + num_tokens, + (num_experts + 7) // 8, + ), f"Bitmap shape {routing_map_bit.shape} != ({num_tokens}, {(num_experts + 7) // 8})" + assert routing_map_bit.dtype == jnp.uint8 + assert jnp.array_equal( + routing_map_bit, packed_expected + ), "Bitmap routing_map disagrees with np.packbits(bytemap, bitorder='little')" + + # Backward parity: grad of probs.sum() must be identical for both formats. + def loss_byte(logits_): + p, _ = fused_topk_with_score_function( + logits_, + topk, + score_function=score_function, + expert_bias=expert_bias, + routing_map_format=RoutingMapFormat.BYTEMAP, + ) + return p.sum() + + def loss_bit(logits_): + p, _ = fused_topk_with_score_function( + logits_, + topk, + score_function=score_function, + expert_bias=expert_bias, + routing_map_format=RoutingMapFormat.BITMAP_U8, + ) + return p.sum() + + grad_byte = jax.jit(jax.grad(loss_byte))(logits) + grad_bit = jax.jit(jax.grad(loss_bit))(logits) + assert jnp.allclose(grad_byte, grad_bit, atol=0.0, rtol=0.0), ( + "Backward grad must be bit-identical across routing_map_format; " + f"max diff = {jnp.abs(grad_byte - grad_bit).max()}" + ) + + +@pytest_parametrize_wrapper("dtype", DTYPES) +@pytest_parametrize_wrapper( + "num_tokens,num_experts,topk", + SCORE_AUX_LOSS_CASES, +) +@pytest_parametrize_wrapper("score_function", SCORE_FUNCTIONS) +@pytest.mark.triton +def test_score_for_aux_loss_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function): + """compute_aux_scores=True path: bitmap routing_map must equal LSB-packed + bytemap; scores must be bitwise identical across formats.""" + from transformer_engine.jax.router import RoutingMapFormat + + logits = make_logits(num_tokens, num_experts, score_function, dtype) + + fwd_byte = jax.jit( + partial( + 
fused_topk_with_score_function, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + routing_map_format=RoutingMapFormat.BYTEMAP, + ) + ) + fwd_bit = jax.jit( + partial( + fused_topk_with_score_function, + topk=topk, + score_function=score_function, + compute_aux_scores=True, + routing_map_format=RoutingMapFormat.BITMAP_U8, + ) + ) + scores_byte, routing_map_byte = fwd_byte(logits) + scores_bit, routing_map_bit = fwd_bit(logits) + + assert jnp.array_equal(scores_byte, scores_bit), "Scores must be identical across formats" + packed_expected = _bytemap_to_bitmap_u8(routing_map_byte) + assert routing_map_bit.shape == (num_tokens, (num_experts + 7) // 8) + assert routing_map_bit.dtype == jnp.uint8 + assert jnp.array_equal( + routing_map_bit, packed_expected + ), "Bitmap routing_map (aux-loss path) disagrees with packed bytemap" diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 274a35b81d..31b9f171da 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -4,6 +4,7 @@ import torch from typing import Optional from transformer_engine.pytorch.router import ( + RoutingMapFormat, fused_topk_with_score_function, fused_compute_score_for_moe_aux_loss, fused_moe_aux_loss, @@ -458,6 +459,145 @@ def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): torch.testing.assert_close(probs.grad, probs_clone.grad, atol=atol, rtol=rtol) +# ============================================================================= +# Test: routing_map BITMAP_U8 vs BYTEMAP parity (fwd + bwd) +# Mirrors tests/jax/test_fused_router.py::test_topk_bitmap_vs_bytemap. +# ============================================================================= + + +def _bytemap_to_bitmap_u8(bytemap: torch.Tensor) -> torch.Tensor: + """Reference packer: bool[T, E] -> uint8[T, ceil(E/8)] LSB-first. + + Matches numpy.packbits(..., bitorder='little'), which is what the JAX-side + parity test uses. 
+ """ + flat = bytemap.to(torch.uint8).cpu().numpy() + import numpy as np + + return torch.from_numpy(np.packbits(flat, axis=-1, bitorder="little")).to(bytemap.device) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize( + "num_tokens,num_experts,topk", + [(128, 32, 4), (256, 128, 8), (256, 130, 8), (128, 1024, 16)], +) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) +def test_topk_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function): + """fused_topk_with_score_function should produce identical probs and an + LSB-packed bitmap routing_map when routing_map_format=BITMAP_U8, and + backward gradients should match the bytemap path exactly.""" + if topk >= num_experts: + pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") + if score_function in ("sigmoid", "sqrtsoftplus"): + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = ( + torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + ) + logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + else: + logits = ( + torch.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + device="cuda", + dtype=dtype, + ) + * 1e-4 + ) + logits = logits.view(num_tokens, num_experts) + + logits_byte = logits.detach().clone().requires_grad_(True) + logits_bit = logits.detach().clone().requires_grad_(True) + + probs_byte, routing_map_byte = fused_topk_with_score_function( + logits=logits_byte, + topk=topk, + use_pre_softmax=False, + num_groups=None, + group_topk=None, + scaling_factor=None, + score_function=score_function, + expert_bias=None, + routing_map_format=RoutingMapFormat.BYTEMAP, + ) + probs_bit, routing_map_bit = fused_topk_with_score_function( + logits=logits_bit, + topk=topk, + use_pre_softmax=False, + num_groups=None, + group_topk=None, + scaling_factor=None, + score_function=score_function, + expert_bias=None, + 
routing_map_format=RoutingMapFormat.BITMAP_U8, + ) + + assert probs_byte.dtype == probs_bit.dtype + torch.testing.assert_close(probs_byte, probs_bit, atol=0.0, rtol=0.0) + + expected_shape = (num_tokens, (num_experts + 7) // 8) + assert ( + routing_map_bit.shape == expected_shape + ), f"Bitmap shape {tuple(routing_map_bit.shape)} != {expected_shape}" + assert routing_map_bit.dtype == torch.uint8 + assert routing_map_byte.dtype == torch.bool + + packed_expected = _bytemap_to_bitmap_u8(routing_map_byte) + torch.testing.assert_close(routing_map_bit, packed_expected, atol=0, rtol=0) + + # Backward parity: grad of probs.sum() must be bit-identical across formats. + probs_byte.sum().backward() + probs_bit.sum().backward() + torch.testing.assert_close(logits_byte.grad, logits_bit.grad, atol=0.0, rtol=0.0) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize( + "num_tokens,num_experts,topk", + [(128, 32, 4), (256, 128, 8), (256, 130, 8)], +) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) +def test_score_for_aux_loss_bitmap_vs_bytemap(dtype, num_tokens, num_experts, topk, score_function): + """fused_compute_score_for_moe_aux_loss: bitmap routing_map must equal + LSB-packed bytemap; scores must be bit-identical across formats.""" + if topk >= num_experts: + pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + + logits_byte = logits.detach().clone().requires_grad_(True) + logits_bit = logits.detach().clone().requires_grad_(True) + + routing_map_byte, scores_byte = fused_compute_score_for_moe_aux_loss( + logits=logits_byte, + topk=topk, + score_function=score_function, + routing_map_format="bytemap", + ) + routing_map_bit, scores_bit = 
fused_compute_score_for_moe_aux_loss( + logits=logits_bit, + topk=topk, + score_function=score_function, + routing_map_format="bitmap_u8", + ) + + torch.testing.assert_close(scores_byte, scores_bit, atol=0.0, rtol=0.0) + + expected_shape = (num_tokens, (num_experts + 7) // 8) + assert routing_map_bit.shape == expected_shape + assert routing_map_bit.dtype == torch.uint8 + assert routing_map_byte.dtype == torch.bool + packed_expected = _bytemap_to_bitmap_u8(routing_map_byte) + torch.testing.assert_close(routing_map_bit, packed_expected, atol=0, rtol=0) + + # Backward parity through scores. + scores_byte.sum().backward() + scores_bit.sum().backward() + torch.testing.assert_close(logits_byte.grad, logits_bit.grad, atol=0.0, rtol=0.0) + + def profile_topk_softmax( dtype, num_tokens, diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 4eb4240d7c..d003b46d2b 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -16,12 +16,14 @@ namespace transformer_engine { namespace fused_router { -template +template __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, float *scores, - bool *routing_map, + uint8_t *routing_map, CompType *intermediate_output) { + constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8); /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. 
@@ -36,10 +38,20 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi CompType *topk_logits_buf = reinterpret_cast(logits_buf + num_experts * num_token_per_block); int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); + // Per-warp bitmap accumulator (only used in BITMAP_U8 mode). See the matching + // comment in fused_topk_with_score_function.cu for the uint32 vs uint8 layout note. + const int bitmap_words_per_warp = (num_experts + 31) / 32; + const int bitmap_row_bytes = (num_experts + 7) / 8; + uint32_t *bitmap_words_buf = nullptr; + if constexpr (kIsBitmap) { + bitmap_words_buf = reinterpret_cast(topk_indices_buf + topk * num_token_per_block); + } // The address of buffers on the current warp CompType *local_logits = logits_buf + warp_id * num_experts; CompType *topk_logits = topk_logits_buf + warp_id * topk; int *topk_indices = topk_indices_buf + warp_id * topk; + uint32_t *local_bitmap_words = + (bitmap_words_buf != nullptr) ? bitmap_words_buf + warp_id * bitmap_words_per_warp : nullptr; /*** * Section: Main Loop @@ -58,13 +70,23 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi * - Load the logits to shmem */ int pos_offset = token_offset_cur_warp * num_experts; - // Clear the routing_map (num_experts) - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - routing_map[pos_offset + i] = false; - if (score_function == 1) { + // Clear the routing_map. In BYTEMAP mode this clears the row in global memory; + // in BITMAP_U8 mode the row is accumulated in shmem and written wholesale at + // the end of the loop, so no global clear is required. 
+ if (score_function == 1) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } + if constexpr (!kIsBitmap) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + routing_map[pos_offset + i] = 0; + } + } else { + for (int j = lane_id; j < bitmap_words_per_warp; j += kThreadsPerWarp) { + local_bitmap_words[j] = 0u; + } + } // Load the logits to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_logits[i] = static_cast(logits[pos_offset + i]); @@ -127,8 +149,22 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi __syncwarp(); // Write the routing_map to the output tensor - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - routing_map[pos_offset + topk_indices[i]] = true; + if constexpr (!kIsBitmap) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = 1; + } + } else { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + int e = topk_indices[i]; + atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32)); + } + __syncwarp(); + uint8_t *bitmap_row = + routing_map + static_cast(token_offset_cur_warp) * bitmap_row_bytes; + const uint8_t *local_bitmap_bytes = reinterpret_cast(local_bitmap_words); + for (int j = lane_id; j < bitmap_row_bytes; j += kThreadsPerWarp) { + bitmap_row[j] = local_bitmap_bytes[j]; + } } // Write the scores to the output tensor for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { @@ -139,32 +175,39 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi } } -template +template void fused_score_for_moe_aux_loss_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, - float *scores, bool *routing_map, CompType *intermediate_output, cudaStream_t stream) { + float *scores, uint8_t *routing_map, CompType *intermediate_output, cudaStream_t 
stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // logits + topk * num_token_per_block * sizeof(CompType) // topk_logits + topk * num_token_per_block * sizeof(int); // topk_indices + if constexpr (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + size_t bitmap_words_per_warp = (num_experts + 31) / 32; + shared_memory_size += + bitmap_words_per_warp * num_token_per_block * sizeof(uint32_t); // bitmap accumulator + } check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; // switch at K=16 where naive O(K^2*E) starts to dominate if (topk < 16) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_score_for_moe_aux_loss_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_score_for_moe_aux_loss_forward_kernel + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(fused_score_for_moe_aux_loss_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_score_for_moe_aux_loss_forward_kernel <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, intermediate_output); } else { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_score_for_moe_aux_loss_forward_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_score_for_moe_aux_loss_forward_kernel + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(fused_score_for_moe_aux_loss_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_score_for_moe_aux_loss_forward_kernel <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, intermediate_output); @@ -172,17 +215,55 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( 
NVTE_CHECK_CUDA(cudaGetLastError()); } +// Build the expected routing_map shape for a given NVTERoutingMapFormat. +// BYTEMAP -> [num_tokens, num_experts] +// BITMAP_U8 -> [num_tokens, ceil(num_experts/8)] +static std::vector expected_routing_map_shape(int num_tokens, int num_experts, + NVTERoutingMapFormat format) { + const size_t t = static_cast(num_tokens); + const size_t e = static_cast(num_experts); + if (format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + return {t, (e + 7) / 8}; + } + return {t, e}; +} + void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, int topk, int score_function, Tensor &scores, - Tensor &routing_map, Tensor &intermediate_output, - cudaStream_t stream) { - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - logits.data.dtype, DataType, - fused_score_for_moe_aux_loss_forward_kernel_launcher( - reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, - score_function, reinterpret_cast(scores.data.dptr), - reinterpret_cast(routing_map.data.dptr), + Tensor &routing_map, + NVTERoutingMapFormat routing_map_format, + Tensor &intermediate_output, cudaStream_t stream) { + NVTE_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, + ", num_experts=", num_experts); + const std::vector dense_shape{static_cast(num_tokens), + static_cast(num_experts)}; + NVTE_CHECK(logits.data.shape == dense_shape, "logits shape must be [num_tokens, num_experts]=[", + num_tokens, ", ", num_experts, "], got ", logits.data.shape); + NVTE_CHECK(scores.data.shape == dense_shape, "scores shape must be [num_tokens, num_experts]=[", + num_tokens, ", ", num_experts, "], got ", scores.data.shape); + NVTE_CHECK(intermediate_output.data.shape == dense_shape, + "intermediate_output shape must be [num_tokens, num_experts]=[", num_tokens, ", ", + num_experts, "], got ", intermediate_output.data.shape); + const auto routing_map_shape = + expected_routing_map_shape(num_tokens, 
num_experts, routing_map_format); + NVTE_CHECK(routing_map.data.shape == routing_map_shape, "routing_map shape mismatch for ", + (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? "BITMAP_U8" : "BYTEMAP"), + "; expected ", routing_map_shape, ", got ", routing_map.data.shape); +#define AUX_LOSS_FORWARD_DISPATCH(RoutingMapFormatVal) \ + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \ + logits.data.dtype, DataType, \ + fused_score_for_moe_aux_loss_forward_kernel_launcher( \ + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, \ + score_function, reinterpret_cast(scores.data.dptr), \ + reinterpret_cast(routing_map.data.dptr), \ reinterpret_cast(intermediate_output.data.dptr), stream);); + if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + AUX_LOSS_FORWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) + } else { + AUX_LOSS_FORWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BYTEMAP) + } +#undef AUX_LOSS_FORWARD_DISPATCH } template @@ -347,17 +428,31 @@ void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, } // namespace fused_router } // namespace transformer_engine +void nvte_fused_score_for_moe_aux_loss_forward_v2(const NVTETensor logits, int num_tokens, + int num_experts, int topk, int score_function, + NVTETensor scores, NVTETensor routing_map, + NVTERoutingMapFormat routing_map_format, + const NVTETensor intermediate_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward_v2); + using namespace transformer_engine; + fused_router::fused_score_for_moe_aux_loss_forward( + *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, score_function, + *convertNVTETensorCheck(scores), *convertNVTETensorCheck(routing_map), routing_map_format, + *convertNVTETensorCheck(intermediate_output), stream); +} + +// Deprecated V1 entry point: forwards to the V2 above with the BYTEMAP layout. +// Kept for ABI compatibility with external C API consumers. 
void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens, int num_experts, int topk, int score_function, - NVTETensor scores, const NVTETensor routing_map, + NVTETensor scores, NVTETensor routing_map, const NVTETensor intermediate_output, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_score_for_moe_aux_loss_forward); - using namespace transformer_engine; - fused_router::fused_score_for_moe_aux_loss_forward( - *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, score_function, - *convertNVTETensorCheck(scores), *convertNVTETensorCheck(routing_map), - *convertNVTETensorCheck(intermediate_output), stream); + nvte_fused_score_for_moe_aux_loss_forward_v2( + logits, num_tokens, num_experts, topk, score_function, scores, routing_map, + NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, stream); } void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_output, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 9f7a830546..4d004583ac 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -15,12 +15,14 @@ namespace transformer_engine { namespace fused_router { -template +template __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, + const BiasType *expert_bias, DataType *probs, uint8_t *routing_map, CompType *intermediate_output) { + constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8); /*** * Section: Global Variables/Addresses init * - Each warp is responsible for one token, and has own shared memory buffer. 
@@ -42,12 +44,24 @@ __global__ void fused_topk_with_score_function_forward_kernel( } else { topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); } + // Per-warp bitmap accumulator (only used in BITMAP_U8 mode). One uint32 per 32 + // experts; final per-token row is copied byte-by-byte to the global uint8 bitmap. + // uint32 packing is bit-for-bit equivalent to uint8 LSB-first packing on + // little-endian devices, which CUDA always is. + const int bitmap_words_per_warp = (num_experts + 31) / 32; + const int bitmap_row_bytes = (num_experts + 7) / 8; + uint32_t *bitmap_words_buf = nullptr; + if constexpr (kIsBitmap) { + bitmap_words_buf = reinterpret_cast(topk_indices_buf + topk * num_token_per_block); + } // The address of buffers on the current warp CompType *scores = scores_buf + warp_id * num_experts; CompType *topk_scores = topk_scores_buf + warp_id * topk; CompType *masked_scores = masked_scores_buf + warp_id * num_experts; CompType *group_scores = group_scores_buf + warp_id * num_groups; int *topk_indices = topk_indices_buf + warp_id * topk; + uint32_t *local_bitmap_words = + (bitmap_words_buf != nullptr) ? bitmap_words_buf + warp_id * bitmap_words_per_warp : nullptr; /*** * Section: Main Loop @@ -66,14 +80,24 @@ __global__ void fused_topk_with_score_function_forward_kernel( * - Load the logits to shmem */ int pos_offset = token_offset_cur_warp * num_experts; - // Clear the probs/routing_map (num_experts) + // Clear the probs (num_experts). In BYTEMAP mode the routing_map row is also + // cleared here; in BITMAP_U8 mode the row is accumulated in shmem and written + // wholesale at the end, so no global clear is required. 
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { probs[pos_offset + i] = 0.0; - routing_map[pos_offset + i] = false; if (score_function == 1) { intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } + if constexpr (!kIsBitmap) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + routing_map[pos_offset + i] = 0; + } + } else { + for (int j = lane_id; j < bitmap_words_per_warp; j += kThreadsPerWarp) { + local_bitmap_words[j] = 0u; + } + } // Load the logits to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { scores[i] = logits[pos_offset + i]; @@ -229,21 +253,39 @@ __global__ void fused_topk_with_score_function_forward_kernel( } // Write the probs/routing_map to the output tensor - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - routing_map[pos_offset + topk_indices[i]] = true; - probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; + if constexpr (!kIsBitmap) { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + routing_map[pos_offset + topk_indices[i]] = 1; + probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; + } + } else { + // BITMAP_U8: OR the selected-expert bit into the per-warp uint32 accumulator + // (shmem atomicOr handles same-word collisions across the topk lanes), then + // copy the bytemap-equivalent bytes out to the global uint8 bitmap row. 
+ for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + int e = topk_indices[i]; + atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32)); + probs[pos_offset + e] = scaling_factor * topk_scores[i]; + } + __syncwarp(); + uint8_t *bitmap_row = + routing_map + static_cast(token_offset_cur_warp) * bitmap_row_bytes; + const uint8_t *local_bitmap_bytes = reinterpret_cast(local_bitmap_words); + for (int j = lane_id; j < bitmap_row_bytes; j += kThreadsPerWarp) { + bitmap_row[j] = local_bitmap_bytes[j]; + } } __threadfence_block(); __syncwarp(); } } -template +template void fused_topk_with_score_function_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, CompType *intermediate_output, - cudaStream_t stream) { + const BiasType *expert_bias, DataType *probs, uint8_t *routing_map, + CompType *intermediate_output, cudaStream_t stream) { size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; size_t shared_memory_size = num_experts * num_token_per_block * sizeof(CompType) // scores @@ -253,22 +295,31 @@ void fused_topk_with_score_function_forward_kernel_launcher( shared_memory_size += num_groups * num_token_per_block * sizeof(CompType); // group_scores shared_memory_size += num_experts * num_token_per_block * sizeof(CompType); // maksed_scores } + if constexpr (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + size_t bitmap_words_per_warp = (num_experts + 31) / 32; + shared_memory_size += + bitmap_words_per_warp * num_token_per_block * sizeof(uint32_t); // bitmap accumulator + } check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); // Radix selection is O(E), independent of K, but it needs 4 passes for 32-bit float; // switch at K=16 where naive O(K^2*E) 
starts to dominate if (topk < 16) { NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_topk_with_score_function_forward_kernel, + fused_topk_with_score_function_forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_topk_with_score_function_forward_kernel + fused_topk_with_score_function_forward_kernel <<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); } else { NVTE_CHECK_CUDA(cudaFuncSetAttribute( - fused_topk_with_score_function_forward_kernel, + fused_topk_with_score_function_forward_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); - fused_topk_with_score_function_forward_kernel + fused_topk_with_score_function_forward_kernel <<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); @@ -276,34 +327,79 @@ void fused_topk_with_score_function_forward_kernel_launcher( NVTE_CHECK_CUDA(cudaGetLastError()); } +// Build the expected routing_map shape for a given NVTERoutingMapFormat. 
+// BYTEMAP -> [num_tokens, num_experts] +// BITMAP_U8 -> [num_tokens, ceil(num_experts/8)] +static std::vector expected_routing_map_shape(int num_tokens, int num_experts, + NVTERoutingMapFormat format) { + const size_t t = static_cast(num_tokens); + const size_t e = static_cast(num_experts); + if (format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + return {t, (e + 7) / 8}; + } + return {t, e}; +} + void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, const Tensor expert_bias, Tensor probs, Tensor routing_map, + NVTERoutingMapFormat routing_map_format, Tensor intermediate_output, cudaStream_t stream) { - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - logits.data.dtype, DataType, - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - expert_bias.data.dtype, BiasType, - fused_topk_with_score_function_forward_kernel_launcher( - reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, - use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, - reinterpret_cast(expert_bias.data.dptr), - reinterpret_cast(probs.data.dptr), - reinterpret_cast(routing_map.data.dptr), + NVTE_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, + ", num_experts=", num_experts); + const std::vector dense_shape{static_cast(num_tokens), + static_cast(num_experts)}; + NVTE_CHECK(logits.data.shape == dense_shape, "logits shape must be [num_tokens, num_experts]=[", + num_tokens, ", ", num_experts, "], got ", logits.data.shape); + NVTE_CHECK(probs.data.shape == dense_shape, "probs shape must be [num_tokens, num_experts]=[", + num_tokens, ", ", num_experts, "], got ", probs.data.shape); + NVTE_CHECK(intermediate_output.data.shape == dense_shape, + "intermediate_output shape must be [num_tokens, num_experts]=[", num_tokens, ", ", + num_experts, "], got ", intermediate_output.data.shape); + 
const auto routing_map_shape = + expected_routing_map_shape(num_tokens, num_experts, routing_map_format); + NVTE_CHECK(routing_map.data.shape == routing_map_shape, "routing_map shape mismatch for ", + (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? "BITMAP_U8" : "BYTEMAP"), + "; expected ", routing_map_shape, ", got ", routing_map.data.shape); + if (expert_bias.data.dptr != nullptr) { + NVTE_CHECK(expert_bias.data.shape == std::vector{static_cast(num_experts)}, + "expert_bias shape must be [num_experts]=[", num_experts, "], got ", + expert_bias.data.shape); + } +#define ROUTER_FORWARD_DISPATCH(RoutingMapFormatVal) \ + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \ + logits.data.dtype, DataType, \ + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \ + expert_bias.data.dtype, BiasType, \ + fused_topk_with_score_function_forward_kernel_launcher( \ + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, \ + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, \ + reinterpret_cast(expert_bias.data.dptr), \ + reinterpret_cast(probs.data.dptr), \ + reinterpret_cast(routing_map.data.dptr), \ reinterpret_cast(intermediate_output.data.dptr), stream););); + if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + ROUTER_FORWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) + } else { + ROUTER_FORWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BYTEMAP) + } +#undef ROUTER_FORWARD_DISPATCH } -template +template __global__ void fused_topk_with_score_function_backward_kernel( // Inputs tensor - const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, + const uint8_t *routing_map, const CompType *intermediate_output, const DataType *grad_probs, // Other parameters int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, // Output tensor DataType *grad_logits) { + constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8); /*** * Section: Global Variables/Addresses 
init * - Each warp is responsible for one token, and has own shared memory buffer. @@ -343,11 +439,19 @@ __global__ void fused_topk_with_score_function_backward_kernel( * - Load the dgrad/output_from_fwd to shmem */ int pos_offset = token_offset_cur_warp * num_experts; - // Load the dgrad/output_from_fwd to shmem + // Load the dgrad/output_from_fwd to shmem. The routing_map source layout + // depends on the RoutingMapFormat template parameter (see NVTERoutingMapFormat). + const int bitmap_row_bytes = (num_experts + 7) / 8; + const uint8_t *bitmap_row = + routing_map + static_cast(token_offset_cur_warp) * bitmap_row_bytes; for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_grad[i] = grad_probs[pos_offset + i]; local_act_from_fwd[i] = intermediate_output[pos_offset + i]; - local_routing_map[i] = routing_map[pos_offset + i]; + if constexpr (!kIsBitmap) { + local_routing_map[i] = routing_map[pos_offset + i] != 0; + } else { + local_routing_map[i] = (bitmap_row[i / 8] >> (i % 8)) & 1u; + } } __threadfence_block(); __syncwarp(); @@ -469,9 +573,9 @@ __global__ void fused_topk_with_score_function_backward_kernel( } } -template +template void fused_topk_with_score_function_backward_kernel_launcher( - const bool *routing_map, const CompType *intermediate_output, const DataType *grad_probs, + const uint8_t *routing_map, const CompType *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { // Meta data for the kernel @@ -483,10 +587,10 @@ void fused_topk_with_score_function_backward_kernel_launcher( + num_experts * num_token_per_block * sizeof(CompType) // comp_buf + num_experts * num_token_per_block * sizeof(bool); // routing_map check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); - NVTE_CHECK_CUDA(cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel, - 
cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_memory_size)); - fused_topk_with_score_function_backward_kernel + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + fused_topk_with_score_function_backward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size)); + fused_topk_with_score_function_backward_kernel <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, grad_logits); @@ -494,38 +598,98 @@ void fused_topk_with_score_function_backward_kernel_launcher( } void fused_topk_with_score_function_backward(const Tensor &routing_map, + NVTERoutingMapFormat routing_map_format, const Tensor &intermediate_output, const Tensor &grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, Tensor &grad_logits, cudaStream_t stream) { - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - grad_logits.data.dtype, DataType, - fused_topk_with_score_function_backward_kernel_launcher( - reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), - reinterpret_cast(grad_probs.data.dptr), num_tokens, num_experts, topk, - use_pre_softmax, scaling_factor, score_function, + NVTE_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, + ", num_experts=", num_experts); + const std::vector dense_shape{static_cast(num_tokens), + static_cast(num_experts)}; + NVTE_CHECK(intermediate_output.data.shape == dense_shape, + "intermediate_output shape must be [num_tokens, num_experts]=[", num_tokens, ", ", + num_experts, "], got ", intermediate_output.data.shape); + NVTE_CHECK(grad_probs.data.shape == dense_shape, + "grad_probs shape must be [num_tokens, num_experts]=[", num_tokens, ", ", num_experts, + "], got ", grad_probs.data.shape); + NVTE_CHECK(grad_logits.data.shape == dense_shape, + "grad_logits shape must be [num_tokens, num_experts]=[", num_tokens, ", ", 
num_experts, + "], got ", grad_logits.data.shape); + const auto routing_map_shape = + expected_routing_map_shape(num_tokens, num_experts, routing_map_format); + NVTE_CHECK(routing_map.data.shape == routing_map_shape, "routing_map shape mismatch for ", + (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 ? "BITMAP_U8" : "BYTEMAP"), + "; expected ", routing_map_shape, ", got ", routing_map.data.shape); +#define ROUTER_BACKWARD_DISPATCH(RoutingMapFormatVal) \ + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \ + grad_logits.data.dtype, DataType, \ + fused_topk_with_score_function_backward_kernel_launcher( \ + reinterpret_cast(routing_map.data.dptr), \ + reinterpret_cast(intermediate_output.data.dptr), \ + reinterpret_cast(grad_probs.data.dptr), num_tokens, num_experts, topk, \ + use_pre_softmax, scaling_factor, score_function, \ reinterpret_cast(grad_logits.data.dptr), stream);); + if (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + ROUTER_BACKWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) + } else { + ROUTER_BACKWARD_DISPATCH(NVTE_ROUTING_MAP_FORMAT_BYTEMAP) + } +#undef ROUTER_BACKWARD_DISPATCH } } // namespace fused_router } // namespace transformer_engine -void nvte_fused_topk_with_score_function_forward( +void nvte_fused_topk_with_score_function_forward_v2( const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, - NVTETensor intermediate_output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fused_topk_with_score_function_forward); + NVTERoutingMapFormat routing_map_format, NVTETensor intermediate_output, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_forward_v2); using namespace transformer_engine; fused_router::fused_topk_with_score_function_forward( *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, static_cast(use_pre_softmax), num_groups, 
group_topk, scaling_factor, score_function, *convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs), - *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), stream); + *convertNVTETensorCheck(routing_map), routing_map_format, + *convertNVTETensorCheck(intermediate_output), stream); } +// Deprecated V1 entry point: forwards to the V2 above with the BYTEMAP layout. +// Kept for ABI compatibility with external C API consumers. +void nvte_fused_topk_with_score_function_forward( + const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, + NVTETensor intermediate_output, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_forward); + nvte_fused_topk_with_score_function_forward_v2( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, routing_map, + NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, stream); +} + +void nvte_fused_topk_with_score_function_backward_v2(const NVTETensor routing_map, + NVTERoutingMapFormat routing_map_format, + const NVTETensor intermediate_output, + const NVTETensor grad_probs, int num_tokens, + int num_experts, int topk, int use_pre_softmax, + float scaling_factor, int score_function, + NVTETensor grad_logits, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_backward_v2); + using namespace transformer_engine; + fused_router::fused_topk_with_score_function_backward( + *convertNVTETensorCheck(routing_map), routing_map_format, + *convertNVTETensorCheck(intermediate_output), *convertNVTETensorCheck(grad_probs), num_tokens, + num_experts, topk, static_cast(use_pre_softmax), scaling_factor, score_function, + *convertNVTETensorCheck(grad_logits), stream); +} + +// Deprecated V1 entry point: 
forwards to the V2 above with the BYTEMAP layout. +// Kept for ABI compatibility with external C API consumers. void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, const NVTETensor intermediate_output, const NVTETensor grad_probs, int num_tokens, @@ -533,10 +697,7 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, float scaling_factor, int score_function, NVTETensor grad_logits, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_topk_with_score_function_backward); - using namespace transformer_engine; - fused_router::fused_topk_with_score_function_backward( - *convertNVTETensorCheck(routing_map), *convertNVTETensorCheck(intermediate_output), - *convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk, - static_cast(use_pre_softmax), scaling_factor, score_function, - *convertNVTETensorCheck(grad_logits), stream); + nvte_fused_topk_with_score_function_backward_v2( + routing_map, NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, grad_probs, num_tokens, + num_experts, topk, use_pre_softmax, scaling_factor, score_function, grad_logits, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 794880d324..6cee10bd39 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -13,7 +13,26 @@ extern "C" { #endif -/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported. +/*! \brief Output format of the routing_map tensor. + * + * BYTEMAP — bool/uint8 tensor of shape [num_tokens, num_experts]; one byte + * per (token, expert) pair, 0 or 1. + * BITMAP_U8 — uint8 tensor of shape [num_tokens, ceil(num_experts/8)]; bit + * (e % 8) of byte (e / 8) of row t is 1 iff token t routes to + * expert e (little-endian / LSB-first packing along the expert + * axis). 
+ */ +typedef enum { + NVTE_ROUTING_MAP_FORMAT_BYTEMAP = 0, + NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 = 1, +} NVTERoutingMapFormat; + +/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported (deprecated). + * + * \deprecated This function has been deprecated in favor of + * nvte_fused_topk_with_score_function_forward_v2, which adds support + * for the NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 routing_map layout. This + * entry point assumes NVTE_ROUTING_MAP_FORMAT_BYTEMAP. * * \param[in] logits Logits from the gating GEMM. * \param[in] num_tokens Number of tokens. @@ -26,7 +45,7 @@ extern "C" { * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[in] expert_bias Expert bias. (Used at the sigmoid/sqrtsoftplus cases) * \param[out] probs Output tensor for probabilities. - * \param[out] routing_map Output tensor for routing map. + * \param[out] routing_map Output tensor for routing map (BYTEMAP layout). * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output) * \param[in] stream CUDA stream used for the operation. */ @@ -36,9 +55,40 @@ void nvte_fused_topk_with_score_function_forward( const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, NVTETensor intermediate_output, cudaStream_t stream); -/*! \brief Backward pass for fused topk + softmax/sigmoid. +/*! \brief Apply topk + softmax/sigmoid to the input tensor. Grouped topk is supported. + * + * \param[in] logits Logits from the gating GEMM. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] use_pre_softmax Whether to use softmax before topk. + * \param[in] num_groups Number of groups in grouped topk. + * \param[in] group_topk Grouped topk value. + * \param[in] scaling_factor Scaling factor. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. + * \param[in] expert_bias Expert bias. 
(Used at the sigmoid/sqrtsoftplus cases) + * \param[out] probs Output tensor for probabilities. + * \param[out] routing_map Output tensor for routing map. Shape depends on + * routing_map_format (see NVTERoutingMapFormat). + * \param[in] routing_map_format NVTERoutingMapFormat value selecting the routing_map + * output layout. The caller is responsible for + * allocating routing_map with the matching shape. + * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_topk_with_score_function_forward_v2( + const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, + NVTERoutingMapFormat routing_map_format, NVTETensor intermediate_output, cudaStream_t stream); + +/*! \brief Backward pass for fused topk + softmax/sigmoid (deprecated). + * + * \deprecated This function has been deprecated in favor of + * nvte_fused_topk_with_score_function_backward_v2. This entry point + * assumes NVTE_ROUTING_MAP_FORMAT_BYTEMAP. * - * \param[in] routing_map Routing map. + * \param[in] routing_map Routing map (BYTEMAP layout). * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) * \param[in] grad_probs Gradient of probs. * \param[in] num_tokens Number of tokens. @@ -57,7 +107,34 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, float scaling_factor, int score_function, NVTETensor grad_logits, cudaStream_t stream); -/*! \brief Forward pass for computing scores/routing map for auxiliary loss. +/*! \brief Backward pass for fused topk + softmax/sigmoid. + * + * \param[in] routing_map Routing map (same layout as produced by forward). 
+ * \param[in] routing_map_format NVTERoutingMapFormat value matching the layout of routing_map. + * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) + * \param[in] grad_probs Gradient of probs. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] use_pre_softmax Whether to use softmax before topk. + * \param[in] scaling_factor Scaling factor. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. + * \param[out] grad_logits Gradient of logits. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_topk_with_score_function_backward_v2(const NVTETensor routing_map, + NVTERoutingMapFormat routing_map_format, + const NVTETensor intermediate_output, + const NVTETensor grad_probs, int num_tokens, + int num_experts, int topk, int use_pre_softmax, + float scaling_factor, int score_function, + NVTETensor grad_logits, cudaStream_t stream); + +/*! \brief Forward pass for computing scores/routing map for auxiliary loss (deprecated). + * + * \deprecated This function has been deprecated in favor of + * nvte_fused_score_for_moe_aux_loss_forward_v2. This entry point + * assumes NVTE_ROUTING_MAP_FORMAT_BYTEMAP. * * \param[in] logits Logits from the gating GEMM. * \param[in] num_tokens Number of tokens. @@ -65,16 +142,38 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, * \param[in] topk Topk value. * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] scores Output tensor for scores. - * \param[in] routing_map Routing map. + * \param[out] routing_map Output tensor for routing map (BYTEMAP layout). * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) * \param[in] stream CUDA stream used for the operation. 
*/ void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_tokens, int num_experts, int topk, int score_function, - NVTETensor scores, const NVTETensor routing_map, + NVTETensor scores, NVTETensor routing_map, const NVTETensor intermediate_output, cudaStream_t stream); +/*! \brief Forward pass for computing scores/routing map for auxiliary loss. + * + * \param[in] logits Logits from the gating GEMM. + * \param[in] num_tokens Number of tokens. + * \param[in] num_experts Number of experts. + * \param[in] topk Topk value. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. + * \param[out] scores Output tensor for scores. + * \param[out] routing_map Output tensor for routing map. Shape depends on + * routing_map_format (see NVTERoutingMapFormat). + * \param[in] routing_map_format NVTERoutingMapFormat value selecting the routing_map + * output layout. + * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_score_for_moe_aux_loss_forward_v2(const NVTETensor logits, int num_tokens, + int num_experts, int topk, int score_function, + NVTETensor scores, NVTETensor routing_map, + NVTERoutingMapFormat routing_map_format, + const NVTETensor intermediate_output, + cudaStream_t stream); + /*! \brief Backward pass for computing scores/routing map for auxiliary loss. * * \param[in] intermediate_output Intermediate output from the forward pass. 
(Softmax/sigmoid output) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 0ae267cbf3..3245439689 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -7,13 +7,14 @@ import jax.numpy as jnp from jax import dtypes, ffi from jax.sharding import NamedSharding, PartitionSpec -from transformer_engine_jax import JAXX_Score_Function +from transformer_engine_jax import JAXX_Routing_Map_Format, JAXX_Score_Function from .base import BasePrimitive, register_primitive from .misc import get_padded_spec __all__ = [ "ScoreFunction", + "RoutingMapFormat", "fused_topk_with_score_function_fwd", "fused_topk_with_score_function_bwd", "fused_moe_aux_loss_fwd", @@ -28,6 +29,19 @@ class ScoreFunction(IntEnum): SOFTMAX = int(JAXX_Score_Function.SOFTMAX) +class RoutingMapFormat(IntEnum): + """Routing-map output layout, synced with C++ JAXX_Routing_Map_Format / NVTERoutingMapFormat. + + BYTEMAP — bool/uint8 tensor of shape [num_tokens, num_experts]. + BITMAP_U8 — uint8 tensor of shape [num_tokens, ceil(num_experts/8)]; bit + (e % 8) of byte (e / 8) of row t is 1 iff token t routes to + expert e (LSB-first packing along the expert axis). 
+ """ + + BYTEMAP = int(JAXX_Routing_Map_Format.BYTEMAP) + BITMAP_U8 = int(JAXX_Routing_Map_Format.BITMAP_U8) + + # =========================================== ================================== # Fused Top-K with Score Function - Forward # ============================================================================= @@ -50,7 +64,9 @@ class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive): 6, 7, 8, - ) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, compute_aux_scores + 9, + ) # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, + # compute_aux_scores, routing_map_format inner_primitive = None outer_primitive = None @@ -65,6 +81,7 @@ def abstract( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): """Abstract evaluation: describe output shapes and dtypes.""" del expert_bias_aval, topk, use_pre_softmax, num_groups, group_topk @@ -72,7 +89,15 @@ def abstract( i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype) i_shape = logits_aval.shape probs_aval = logits_aval.update(shape=i_shape, dtype=i_dtype) - routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) + # routing_map shape/dtype depends on the format. In BITMAP_U8 mode the + # expert axis is bit-packed LSB-first into uint8 bytes, so the trailing + # dim becomes ceil(num_experts/8). + if int(routing_map_format) == int(RoutingMapFormat.BITMAP_U8): + packed_experts = (i_shape[-1] + 7) // 8 + routing_map_shape = (*i_shape[:-1], packed_experts) + routing_map_aval = logits_aval.update(shape=routing_map_shape, dtype=jnp.uint8) + else: + routing_map_aval = logits_aval.update(shape=i_shape, dtype=jnp.bool_) # The CUDA kernel always uses float32 (CompType) for intermediate # computations (softmax/sigmoid values saved for backward). 
intermediate_aval = logits_aval.update(shape=i_shape, dtype=jnp.float32) @@ -91,6 +116,7 @@ def lowering( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): return ffi.ffi_lowering(FusedTopkWithScoreFunctionFwdPrimitive.name)( ctx, @@ -103,6 +129,7 @@ def lowering( scaling_factor=scaling_factor, score_function=score_function, compute_aux_scores=compute_aux_scores, + routing_map_format=routing_map_format, ) @staticmethod @@ -116,6 +143,7 @@ def impl( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): if FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is None: raise RuntimeError( @@ -131,6 +159,7 @@ def impl( scaling_factor=scaling_factor, score_function=score_function, compute_aux_scores=compute_aux_scores, + routing_map_format=routing_map_format, ) @staticmethod @@ -145,6 +174,7 @@ def batcher( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): if FusedTopkWithScoreFunctionFwdPrimitive.outer_primitive is None: raise RuntimeError( @@ -163,6 +193,7 @@ def batcher( scaling_factor=scaling_factor, score_function=score_function, compute_aux_scores=compute_aux_scores, + routing_map_format=routing_map_format, ), (logits_bdim, logits_bdim, logits_bdim), ) @@ -176,6 +207,7 @@ def partition( scaling_factor, score_function, compute_aux_scores, + routing_map_format, mesh, arg_infos, result_infos, @@ -183,7 +215,14 @@ def partition( del result_infos logits_spec = get_padded_spec(arg_infos[0]) out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) - routing_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) + # For bitmap mode the trailing dim is ceil(E/8) instead of E. We keep the + # routing_map sharded the same way logits is along all non-trailing dims + # and replicate the (now packed) expert axis to avoid sharding mid-byte. 
+ if int(routing_map_format) == int(RoutingMapFormat.BITMAP_U8): + routing_spec = (*logits_spec[:-1], None) if len(logits_spec) >= 1 else logits_spec + else: + routing_spec = logits_spec + routing_sharding = NamedSharding(mesh, PartitionSpec(*routing_spec)) intermediate_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec)) out_shardings = [out_sharding, routing_sharding, intermediate_sharding] arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding) @@ -199,13 +238,28 @@ def sharded_impl(logits, expert_bias): scaling_factor, score_function, compute_aux_scores, + routing_map_format, ) return mesh, sharded_impl, out_shardings, arg_shardings @staticmethod - def shardy_sharding_rule(*args): - del args + def shardy_sharding_rule(*args, **kwargs): + # Static args arrive in impl_static_args order: routing_map_format is the + # last (8th) static arg. Be defensive about positional-vs-kwarg passing + # across JAX versions. + routing_map_format = kwargs.get("routing_map_format") + if routing_map_format is None and len(args) >= 8: + routing_map_format = args[7] + # routing_map's expert axis is the same as logits in BYTEMAP mode; in + # BITMAP_U8 mode it's a packed-byte axis distinct from num_experts. 
+ if routing_map_format is not None and int(routing_map_format) == int( + RoutingMapFormat.BITMAP_U8 + ): + return ( + "num_tokens num_experts, bias_dim ->" + " num_tokens num_experts, num_tokens packed_experts, num_tokens num_experts" + ) return ( "num_tokens num_experts, bias_dim -> num_tokens num_experts, num_tokens num_experts," " num_tokens num_experts" @@ -234,7 +288,8 @@ class FusedTopkWithScoreFunctionBwdPrimitive(BasePrimitive): 5, 6, 7, - ) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores + 8, + ) # topk, use_pre_softmax, scaling_factor, score_function, compute_aux_scores, routing_map_format inner_primitive = None outer_primitive = None @@ -248,9 +303,10 @@ def abstract( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): del topk, use_pre_softmax, scaling_factor, score_function - del compute_aux_scores, routing_map_aval + del compute_aux_scores, routing_map_aval, routing_map_format return intermediate_aval.update( shape=intermediate_aval.shape, dtype=dtypes.canonicalize_dtype(grad_probs_aval.dtype), @@ -268,6 +324,7 @@ def lowering( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): return ffi.ffi_lowering(FusedTopkWithScoreFunctionBwdPrimitive.name)( ctx, @@ -279,6 +336,7 @@ def lowering( scaling_factor=scaling_factor, score_function=score_function, compute_aux_scores=compute_aux_scores, + routing_map_format=routing_map_format, ) @staticmethod @@ -291,6 +349,7 @@ def impl( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): if FusedTopkWithScoreFunctionBwdPrimitive.inner_primitive is None: raise RuntimeError( @@ -305,6 +364,7 @@ def impl( scaling_factor=scaling_factor, score_function=score_function, compute_aux_scores=compute_aux_scores, + routing_map_format=routing_map_format, ) @staticmethod @@ -317,6 +377,7 @@ def batcher( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): if 
FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive is None: raise RuntimeError( @@ -334,6 +395,7 @@ def batcher( scaling_factor=scaling_factor, score_function=score_function, compute_aux_scores=compute_aux_scores, + routing_map_format=routing_map_format, ), grad_probs_bdim, ) @@ -345,11 +407,12 @@ def partition( scaling_factor, score_function, compute_aux_scores, + routing_map_format, mesh, arg_infos, result_infos, ): - del result_infos + del result_infos, routing_map_format grad_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec)) arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding) @@ -364,13 +427,24 @@ def sharded_impl(routing_map, intermediate, grad_probs): scaling_factor, score_function, compute_aux_scores, + routing_map_format, ) return mesh, sharded_impl, out_sharding, arg_shardings @staticmethod - def shardy_sharding_rule(*args): - del args + def shardy_sharding_rule(*args, **kwargs): + # routing_map_format is the 6th static arg (impl_static_args index 5). + routing_map_format = kwargs.get("routing_map_format") + if routing_map_format is None and len(args) >= 6: + routing_map_format = args[5] + if routing_map_format is not None and int(routing_map_format) == int( + RoutingMapFormat.BITMAP_U8 + ): + return ( + "num_tokens packed_experts, num_tokens num_experts, num_tokens num_experts ->" + " num_tokens num_experts" + ) return ( "num_tokens num_experts, num_tokens num_experts, num_tokens num_experts -> num_tokens" " num_experts" @@ -592,6 +666,7 @@ def fused_topk_with_score_function_fwd( score_function, expert_bias: jnp.ndarray, compute_aux_scores: bool = False, + routing_map_format: int = int(RoutingMapFormat.BYTEMAP), ): """ Fused top-k with score function forward pass. @@ -620,6 +695,9 @@ def fused_topk_with_score_function_fwd( Expert bias (only used with sigmoid). Pass empty array if unused. 
compute_aux_scores : bool If True, compute clean scores for aux loss instead of full top-k. + routing_map_format : int + RoutingMapFormat.BYTEMAP (default, bool[T, E]) or RoutingMapFormat.BITMAP_U8 + (uint8[T, ceil(E/8)], LSB-first along the expert axis). Returns ------- @@ -635,6 +713,7 @@ def fused_topk_with_score_function_fwd( scaling_factor=float(scaling_factor), score_function=int(score_function), compute_aux_scores=int(compute_aux_scores), + routing_map_format=int(routing_map_format), ) @@ -647,12 +726,17 @@ def fused_topk_with_score_function_bwd( scaling_factor: float, score_function, compute_aux_scores: bool = False, + routing_map_format: int = int(RoutingMapFormat.BYTEMAP), ): """ Fused top-k with score function backward pass. When compute_aux_scores=True, routing_map is ignored and the score-for-aux-loss backward kernel is used instead. + + routing_map_format must match the layout produced by the matching forward + call (BYTEMAP or BITMAP_U8). The CUDA kernel is specialized on this format at + compile time (template parameter) when loading bits into shmem. 
""" return FusedTopkWithScoreFunctionBwdPrimitive.outer_primitive.bind( routing_map, @@ -663,6 +747,7 @@ def fused_topk_with_score_function_bwd( scaling_factor=float(scaling_factor), score_function=int(score_function), compute_aux_scores=int(compute_aux_scores), + routing_map_format=int(routing_map_format), ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..d8bf1788a4 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -251,6 +251,7 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Routing_Map_Format); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout); diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c6f6f87cb4..62f95ad529 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include @@ -127,6 +128,14 @@ enum class JAXX_Score_Function : int64_t { SOFTMAX = 1, }; +// Mirror of NVTERoutingMapFormat for JAX FFI plumbing. Values are taken +// directly from the C enum in +// transformer_engine/common/include/transformer_engine/fused_router.h. 
+enum class JAXX_Routing_Map_Format : int64_t { + BYTEMAP = NVTE_ROUTING_MAP_FORMAT_BYTEMAP, + BITMAP_U8 = NVTE_ROUTING_MAP_FORMAT_BITMAP_U8, +}; + enum class JAXX_Collective_Op : int64_t { NONE = 0, ALL_GATHER = 1, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..cc79ff3737 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -217,6 +217,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("SOFTMAX", JAXX_Score_Function::SOFTMAX) .export_values(); + pybind11::enum_(m, "JAXX_Routing_Map_Format", pybind11::module_local()) + .value("BYTEMAP", JAXX_Routing_Map_Format::BYTEMAP) + .value("BITMAP_U8", JAXX_Routing_Map_Format::BITMAP_U8) + .export_values(); + pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) .value("NONE", JAXX_Collective_Op::NONE) .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER) diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp index 79daec3f07..89cbc27fdb 100644 --- a/transformer_engine/jax/csrc/extensions/router.cpp +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -21,10 +21,11 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( Buffer_Type logits_buf, // [num_tokens, num_experts] Buffer_Type expert_bias_buf, // [num_experts] or empty Result_Type probs_buf, // [num_tokens, num_experts] (or scores when compute_aux_scores) - Result_Type routing_map_buf, // [num_tokens, num_experts] + Result_Type routing_map_buf, // BYTEMAP: [T, E] uint8 / BITMAP_U8: [T, ceil(E/8)] uint8 Result_Type intermediate_buf, // [num_tokens, num_experts] int64_t topk, int64_t use_pre_softmax, int64_t num_groups, int64_t group_topk, - double scaling_factor, JAXX_Score_Function score_function, int64_t compute_aux_scores) { + double scaling_factor, JAXX_Score_Function score_function, int64_t compute_aux_scores, + 
JAXX_Routing_Map_Format routing_map_format) { auto dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type()); auto dims = logits_buf.dimensions(); auto num_tokens = static_cast(product(dims, 0, dims.size() - 1)); @@ -40,7 +41,16 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( std::vector{static_cast(num_tokens), static_cast(num_experts)}; auto logits_tensor = TensorWrapper(logits, flat_shape, dtype); auto probs_tensor = TensorWrapper(probs, flat_shape, dtype); - auto routing_map_tensor = TensorWrapper(routing_map, flat_shape, DType::kByte); + // Flatten the routing_map shape to match the kernel's 2D indexing. The trailing + // dim depends on the requested format: num_experts for BYTEMAP, ceil(num_experts/8) + // for BITMAP_U8. Keeping this 2D also lets the kernel's shape NVTE_CHECKs fire. + auto routing_map_format_nvte = static_cast(routing_map_format); + size_t routing_map_trailing = (routing_map_format_nvte == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) + ? static_cast((num_experts + 7) / 8) + : static_cast(num_experts); + auto routing_map_shape = + std::vector{static_cast(num_tokens), routing_map_trailing}; + auto routing_map_tensor = TensorWrapper(routing_map, routing_map_shape, DType::kByte); // intermediate is always float32 (CompType) regardless of logits dtype. 
auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf->element_type()); NVTE_CHECK( @@ -51,10 +61,10 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( auto intermediate_tensor = TensorWrapper(intermediate, flat_shape, DType::kFloat32); if (compute_aux_scores) { - nvte_fused_score_for_moe_aux_loss_forward( + nvte_fused_score_for_moe_aux_loss_forward_v2( logits_tensor.data(), num_tokens, num_experts, static_cast(topk), static_cast(score_function), probs_tensor.data(), routing_map_tensor.data(), - intermediate_tensor.data(), stream); + routing_map_format_nvte, intermediate_tensor.data(), stream); } else { auto bias_dims = expert_bias_buf.dimensions(); auto expert_bias_tensor = @@ -63,12 +73,12 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI( convert_ffi_datatype_to_te_dtype(expert_bias_buf.element_type())) : TensorWrapper(); - nvte_fused_topk_with_score_function_forward( + nvte_fused_topk_with_score_function_forward_v2( logits_tensor.data(), num_tokens, num_experts, static_cast(topk), static_cast(use_pre_softmax), static_cast(num_groups), static_cast(group_topk), static_cast(scaling_factor), static_cast(score_function), expert_bias_tensor.data(), probs_tensor.data(), - routing_map_tensor.data(), intermediate_tensor.data(), stream); + routing_map_tensor.data(), routing_map_format_nvte, intermediate_tensor.data(), stream); } return ffi_with_cuda_error_check(); @@ -89,7 +99,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler, .Attr("group_topk") .Attr("scaling_factor") .Attr("score_function") - .Attr("compute_aux_scores"), + .Attr("compute_aux_scores") + .Attr("routing_map_format"), FFI_CudaGraph_Traits); // ============================================================================ @@ -98,12 +109,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionForwardHandler, Error_Type FusedTopkWithScoreFunctionBackwardFFI( cudaStream_t stream, - Buffer_Type routing_map_buf, // [num_tokens, num_experts] (unused when 
compute_aux_scores) + Buffer_Type routing_map_buf, // bytemap [T,E] uint8 or bitmap [T, ceil(E/8)] uint8 Buffer_Type intermediate_buf, // [num_tokens, num_experts] Buffer_Type grad_probs_buf, // [num_tokens, num_experts] (grad_scores when compute_aux_scores) Result_Type grad_logits_buf, // [num_tokens, num_experts] int64_t topk, int64_t use_pre_softmax, double scaling_factor, - JAXX_Score_Function score_function, int64_t compute_aux_scores) { + JAXX_Score_Function score_function, int64_t compute_aux_scores, + JAXX_Routing_Map_Format routing_map_format) { // intermediate is always float32 (CompType) regardless of logits dtype. auto intermediate_dtype = convert_ffi_datatype_to_te_dtype(intermediate_buf.element_type()); NVTE_CHECK( @@ -130,14 +142,20 @@ Error_Type FusedTopkWithScoreFunctionBackwardFFI( static_cast(score_function), grad_logits_tensor.data(), stream); } else { + auto routing_map_format_nvte = static_cast(routing_map_format); + size_t routing_map_trailing = (routing_map_format_nvte == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) + ? 
static_cast((num_experts + 7) / 8) + : static_cast(num_experts); + auto routing_map_shape = + std::vector{static_cast(num_tokens), routing_map_trailing}; auto routing_map_tensor = - TensorWrapper(routing_map_buf.untyped_data(), flat_shape, DType::kByte); + TensorWrapper(routing_map_buf.untyped_data(), routing_map_shape, DType::kByte); - nvte_fused_topk_with_score_function_backward( - routing_map_tensor.data(), intermediate_tensor.data(), grad_probs_tensor.data(), num_tokens, - num_experts, static_cast(topk), static_cast(use_pre_softmax), - static_cast(scaling_factor), static_cast(score_function), - grad_logits_tensor.data(), stream); + nvte_fused_topk_with_score_function_backward_v2( + routing_map_tensor.data(), routing_map_format_nvte, intermediate_tensor.data(), + grad_probs_tensor.data(), num_tokens, num_experts, static_cast(topk), + static_cast(use_pre_softmax), static_cast(scaling_factor), + static_cast(score_function), grad_logits_tensor.data(), stream); } return ffi_with_cuda_error_check(); @@ -155,7 +173,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler, .Attr("use_pre_softmax") .Attr("scaling_factor") .Attr("score_function") - .Attr("compute_aux_scores"), + .Attr("compute_aux_scores") + .Attr("routing_map_format"), FFI_CudaGraph_Traits); // ============================================================================ diff --git a/transformer_engine/jax/router.py b/transformer_engine/jax/router.py index 65f2e8a7ff..80ec42a95f 100644 --- a/transformer_engine/jax/router.py +++ b/transformer_engine/jax/router.py @@ -27,6 +27,7 @@ from transformer_engine.jax.cpp_extensions.router import ( ScoreFunction, + RoutingMapFormat, fused_topk_with_score_function_fwd, fused_topk_with_score_function_bwd, fused_moe_aux_loss_fwd, @@ -35,11 +36,35 @@ __all__ = [ "ScoreFunction", + "RoutingMapFormat", "fused_topk_with_score_function", "fused_moe_aux_loss", ] +def _validate_routing_map_format( + routing_map_format: Union[str, RoutingMapFormat, int], 
+) -> RoutingMapFormat: + """Validate and convert routing_map_format to a RoutingMapFormat enum.""" + if isinstance(routing_map_format, RoutingMapFormat): + return routing_map_format + if isinstance(routing_map_format, int): + try: + return RoutingMapFormat(routing_map_format) + except ValueError: + raise ValueError( + "routing_map_format int must match a RoutingMapFormat value; " + f"got {routing_map_format!r}" + ) from None + try: + return RoutingMapFormat[routing_map_format.upper()] + except (KeyError, AttributeError): + raise ValueError( + "routing_map_format must be 'bytemap', 'bitmap_u8', a RoutingMapFormat enum, " + f"or the matching int; got {routing_map_format!r}" + ) from None + + def _validate_score_function(score_function: Union[str, ScoreFunction]) -> ScoreFunction: """Validate and convert score_function to a ScoreFunction enum.""" if isinstance(score_function, ScoreFunction): @@ -68,6 +93,7 @@ def fused_topk_with_score_function( score_function: Union[str, ScoreFunction] = ScoreFunction.SOFTMAX, expert_bias: Optional[jnp.ndarray] = None, compute_aux_scores: bool = False, + routing_map_format: Union[str, RoutingMapFormat, int] = RoutingMapFormat.BYTEMAP, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Fused top-k with score function router. @@ -108,6 +134,11 @@ def fused_topk_with_score_function( compute_aux_scores : bool If True, use the clean score-for-aux-loss kernel. Returns dense scores over all experts instead of sparse probs. + routing_map_format : Union[str, RoutingMapFormat, int] + Output layout for routing_map. "bytemap" / RoutingMapFormat.BYTEMAP (default) + returns a bool[T, E] tensor; "bitmap_u8" / RoutingMapFormat.BITMAP_U8 returns a + uint8[T, ceil(E/8)] tensor with bit (e % 8) of byte (e / 8) set when token t + routes to expert e (LSB-first along the expert axis). Returns ------- @@ -117,8 +148,10 @@ def fused_topk_with_score_function( When compute_aux_scores=True: Dense score tensor, shape [num_tokens, num_experts]. 
All expert positions contain scores. routing_map : jnp.ndarray - Boolean mask, shape [num_tokens, num_experts]. - True at selected expert positions. + Shape/dtype depend on routing_map_format: + - BYTEMAP: bool[num_tokens, num_experts]; True at selected expert positions. + - BITMAP_U8: uint8[num_tokens, ceil(num_experts/8)]; LSB-first bit-packed + along the expert axis. """ if not isinstance(scaling_factor, (int, float)): raise TypeError( @@ -127,6 +160,7 @@ def fused_topk_with_score_function( ) score_function = _validate_score_function(score_function) + routing_map_format = _validate_routing_map_format(routing_map_format) if compute_aux_scores: expert_bias = jnp.empty((0,), dtype=logits.dtype) @@ -153,12 +187,13 @@ def fused_topk_with_score_function( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ) return probs_or_scores, routing_map -@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8)) +@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6, 7, 8, 9)) def _fused_topk_with_score_function( logits: jnp.ndarray, expert_bias: jnp.ndarray, @@ -169,6 +204,7 @@ def _fused_topk_with_score_function( scaling_factor: float, score_function: ScoreFunction, compute_aux_scores: bool, + routing_map_format: RoutingMapFormat, ) -> Tuple[jnp.ndarray, jnp.ndarray]: (probs, routing_map), _ = _fused_topk_with_score_function_fwd( logits, @@ -180,6 +216,7 @@ def _fused_topk_with_score_function( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ) return probs, routing_map @@ -194,6 +231,7 @@ def _fused_topk_with_score_function_fwd( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ): probs, routing_map, saved_scores = fused_topk_with_score_function_fwd( logits, @@ -205,6 +243,7 @@ def _fused_topk_with_score_function_fwd( score_function, expert_bias, compute_aux_scores, + routing_map_format, ) residuals = (routing_map, saved_scores) return (probs, routing_map), residuals @@ -218,6 +257,7 @@ def 
_fused_topk_with_score_function_bwd( scaling_factor, score_function, compute_aux_scores, + routing_map_format, residuals, g, ): @@ -234,6 +274,7 @@ def _fused_topk_with_score_function_bwd( scaling_factor, score_function, compute_aux_scores, + routing_map_format, ) return grad_logits, None diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..f6efcab254 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -29,16 +29,18 @@ namespace transformer_engine::pytorch { std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias); + std::optional expert_bias, + NVTERoutingMapFormat routing_map_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP); -void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - at::Tensor grad_logits, int topk, bool use_pre_softmax, - std::optional scaling_factor, - std::string score_function); +void fused_topk_with_score_function_bwd( + int num_tokens, int num_experts, at::Tensor routing_map, at::Tensor intermediate_output, + at::Tensor grad_probs, at::Tensor grad_logits, int topk, bool use_pre_softmax, + std::optional scaling_factor, std::string score_function, + NVTERoutingMapFormat routing_map_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP); std::tuple fused_score_for_moe_aux_loss_fwd( - at::Tensor logits, int topk, std::string score_function); + at::Tensor logits, int topk, std::string score_function, + NVTERoutingMapFormat routing_map_format = NVTE_ROUTING_MAP_FORMAT_BYTEMAP); void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, at::Tensor intermediate_output, at::Tensor grad_probs, diff --git 
a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..4d60b5d209 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -129,6 +129,41 @@ void init_extension() { }); } +// Pybind11 registrations for the fused MoE router kernels. Split out of +// PYBIND11_MODULE() to keep that function under the cpplint readability/fn_size +// limit. +void init_router_bindings(pybind11::module &m) { + pybind11::enum_(m, "NVTERoutingMapFormat", pybind11::module_local()) + .value("BYTEMAP", NVTE_ROUTING_MAP_FORMAT_BYTEMAP) + .value("BITMAP_U8", NVTE_ROUTING_MAP_FORMAT_BITMAP_U8); + m.def("fused_topk_with_score_function_fwd", &fused_topk_with_score_function_fwd, + py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), + py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"), + py::arg("expert_bias"), py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP, + "Fused topk with score function fwd"); + m.def("fused_topk_with_score_function_bwd", &fused_topk_with_score_function_bwd, + py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"), + py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("grad_logits"), + py::arg("topk"), py::arg("use_pre_softmax"), py::arg("scaling_factor"), + py::arg("score_function"), py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP, + "Fused topk with score function bwd"); + m.def("fused_score_for_moe_aux_loss_fwd", &fused_score_for_moe_aux_loss_fwd, py::arg("logits"), + py::arg("topk"), py::arg("score_function"), + py::arg("routing_map_format") = NVTE_ROUTING_MAP_FORMAT_BYTEMAP, + "Fused aux loss with score function fwd"); + m.def("fused_score_for_moe_aux_loss_bwd", &fused_score_for_moe_aux_loss_bwd, + py::arg("num_tokens"), py::arg("num_experts"), py::arg("intermediate_output"), + py::arg("grad_scores"), 
py::arg("grad_logits"), py::arg("topk"), py::arg("score_function"), + "Fused aux loss with score function bwd"); + m.def("fused_moe_aux_loss_fwd", &fused_moe_aux_loss_fwd, py::arg("probs"), + py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_experts"), + py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), py::arg("coeff"), + "Fused aux loss fwd"); + m.def("fused_moe_aux_loss_bwd", &fused_moe_aux_loss_bwd, py::arg("Const_buf"), + py::arg("tokens_per_expert"), py::arg("num_rows"), py::arg("num_cols"), + py::arg("grad_aux_loss"), "Fused aux loss bwd"); +} + } // namespace transformer_engine::pytorch #include "common/util/pybind_helper.h" @@ -451,31 +486,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Apply QKV RoPE BWD", py::call_guard()); // fused router - m.def("fused_topk_with_score_function_fwd", - &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), - py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), - py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"), - "Fused topk with score function fwd"); - m.def("fused_topk_with_score_function_bwd", - &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), - py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), - py::arg("grad_probs"), py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), "Fused topk with score function bwd"); - m.def("fused_score_for_moe_aux_loss_fwd", - &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), - py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function fwd"); - m.def("fused_score_for_moe_aux_loss_bwd", - &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, py::arg("num_tokens"), - py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"), - 
py::arg("grad_logits"), py::arg("topk"), py::arg("score_function"), - "Fused aux loss with score function bwd"); - m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, - py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), - py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), - py::arg("coeff"), "Fused aux loss fwd"); - m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd, - py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), - py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); + transformer_engine::pytorch::init_router_bindings(m); // Dropout m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG", diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 4df64d8e26..11490725e2 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -12,10 +12,23 @@ namespace transformer_engine::pytorch { static std::map score_function_map = { {"sigmoid", 0}, {"softmax", 1}, {"sqrtsoftplus", 2}}; +// Allocate a routing_map output tensor with the layout that matches the +// requested NVTERoutingMapFormat. +// BYTEMAP -> bool[num_tokens, num_experts] +// BITMAP_U8 -> uint8[num_tokens, ceil(num_experts/8)], LSB-first along the +// expert axis. 
+static at::Tensor allocate_routing_map(int num_tokens, int num_experts, + NVTERoutingMapFormat format) { + if (format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8) { + return at::empty({num_tokens, (num_experts + 7) / 8}, at::dtype(at::kByte).device(at::kCUDA)); + } + return at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); +} + std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias) { + std::optional expert_bias, NVTERoutingMapFormat routing_map_format) { int num_tokens = logits.size(0); int num_experts = logits.size(1); // Check if the input is valid @@ -44,8 +57,7 @@ std::tuple fused_topk_with_score_function_fw // Construct the output tensor at::Tensor probs = at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); - at::Tensor routing_map = - at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); + at::Tensor routing_map = allocate_routing_map(num_tokens, num_experts, routing_map_format); // Intermediate output is used to store the output of the softmax/sigmoid function at::Tensor intermediate_output = at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); @@ -59,11 +71,11 @@ std::tuple fused_topk_with_score_function_fw expert_bias_cu = makeTransformerEngineTensor(expert_bias.value()); } - nvte_fused_topk_with_score_function_forward( + nvte_fused_topk_with_score_function_forward_v2( logits_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, num_groups_value, group_topk_value, scaling_factor_value, score_function_map[score_function], - expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(), - at::cuda::getCurrentCUDAStream()); + expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), routing_map_format, + 
intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); return std::make_tuple(probs, routing_map, intermediate_output); } @@ -72,7 +84,8 @@ void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Ten at::Tensor intermediate_output, at::Tensor grad_probs, at::Tensor grad_logits, int topk, bool use_pre_softmax, std::optional scaling_factor, - std::string score_function) { + std::string score_function, + NVTERoutingMapFormat routing_map_format) { // Get the value of the parameters auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; auto score_function_value = score_function_map[score_function]; @@ -82,14 +95,15 @@ void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Ten auto grad_probs_cu = makeTransformerEngineTensor(grad_probs); auto grad_logits_cu = makeTransformerEngineTensor(grad_logits); - nvte_fused_topk_with_score_function_backward( - routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens, - num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value, - grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_topk_with_score_function_backward_v2( + routing_map_cu.data(), routing_map_format, intermediate_output_cu.data(), + grad_probs_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, scaling_factor_value, + score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); } std::tuple fused_score_for_moe_aux_loss_fwd( - at::Tensor logits, int topk, std::string score_function) { + at::Tensor logits, int topk, std::string score_function, + NVTERoutingMapFormat routing_map_format) { int num_tokens = logits.size(0); int num_experts = logits.size(1); // Check if the input is valid @@ -104,8 +118,7 @@ std::tuple fused_score_for_moe_aux_loss_fwd( // Construct the output tensor at::Tensor scores = at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); - at::Tensor 
routing_map = - at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); + at::Tensor routing_map = allocate_routing_map(num_tokens, num_experts, routing_map_format); at::Tensor intermediate_output = at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); @@ -114,9 +127,10 @@ std::tuple fused_score_for_moe_aux_loss_fwd( auto routing_map_cu = makeTransformerEngineTensor(routing_map); auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); - nvte_fused_score_for_moe_aux_loss_forward( + nvte_fused_score_for_moe_aux_loss_forward_v2( logits_cu.data(), num_tokens, num_experts, topk, score_function_value, scores_cu.data(), - routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); + routing_map_cu.data(), routing_map_format, intermediate_output_cu.data(), + at::cuda::getCurrentCUDAStream()); return std::make_tuple(scores, routing_map, intermediate_output); } diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index b56b1cd5eb..4ee4d36f7a 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -13,12 +13,61 @@ - Only cast to low-precision when necessary and the casting only happens in writing to global memory. For example, the gradient is required to have the same dtype as the input. """ -from typing import Optional +from typing import Optional, Union import torch import transformer_engine_torch as tex +# Re-export the C++ enum NVTERoutingMapFormat under a friendlier Python name. +# Members: +# RoutingMapFormat.BYTEMAP — bool[num_tokens, num_experts] +# RoutingMapFormat.BITMAP_U8 — uint8[num_tokens, ceil(num_experts/8)], +# LSB-first / little-endian packing along the +# expert axis. 
+RoutingMapFormat = tex.NVTERoutingMapFormat + + +_ROUTING_MAP_FORMAT_FROM_STRING = { + "bytemap": RoutingMapFormat.BYTEMAP, + "bitmap_u8": RoutingMapFormat.BITMAP_U8, +} + + +def _validate_routing_map_format( + routing_map_format: Union[str, "RoutingMapFormat", int], +) -> "RoutingMapFormat": + """Coerce user-supplied routing_map_format into the NVTERoutingMapFormat enum. + + Accepts an enum value, an int with one of the enum's values, or one of the + case-insensitive strings ``"bytemap"`` / ``"bitmap_u8"``. String parsing is + only supported at this outer API boundary; the rest of the stack uses the + enum directly. + """ + if isinstance(routing_map_format, RoutingMapFormat): + return routing_map_format + if isinstance(routing_map_format, str): + try: + return _ROUTING_MAP_FORMAT_FROM_STRING[routing_map_format.lower()] + except KeyError: + raise ValueError( + "routing_map_format string must be 'bytemap' or 'bitmap_u8'; " + f"got {routing_map_format!r}" + ) from None + if isinstance(routing_map_format, int): + try: + return RoutingMapFormat(routing_map_format) + except ValueError: + raise ValueError( + "routing_map_format int must match a RoutingMapFormat value; " + f"got {routing_map_format!r}" + ) from None + raise TypeError( + "routing_map_format must be a RoutingMapFormat enum, an int matching one " + f"of its values, or 'bytemap' / 'bitmap_u8'; got {routing_map_format!r}" + ) + + class FusedTopkScoreFunction(torch.autograd.Function): """ Fused Topk with Score Function router. 
@@ -36,6 +85,7 @@ def forward( scaling_factor: Optional[float], score_function: str, expert_bias: Optional[torch.Tensor], + routing_map_format: "RoutingMapFormat", ): # pylint: disable=missing-function-docstring # Save the shape of the logits @@ -53,16 +103,23 @@ def forward( scaling_factor, score_function, expert_bias, + routing_map_format, ) - # Restore the shape - probs = probs.view(tensor_shape) + # Save the flat 2D routing_map for backward (kernel indexes by + # num_tokens x trailing_dim), then restore the leading dims of the + # input on the returned outputs. The trailing dim of routing_map + # depends on the format: num_experts for BYTEMAP, ceil(num_experts/8) + # for BITMAP_U8. ctx.save_for_backward(routing_map, intermediate_output) + probs = probs.view(tensor_shape) + routing_map = routing_map.view(*tensor_shape[:-1], routing_map.shape[-1]) ctx.num_tokens = num_tokens ctx.num_experts = num_experts ctx.use_pre_softmax = use_pre_softmax ctx.topk = topk ctx.scaling_factor = scaling_factor ctx.score_function = score_function + ctx.routing_map_format = routing_map_format ctx.logits_dtype = logits.dtype return probs, routing_map @@ -88,10 +145,11 @@ def backward(ctx, grad_probs, _): ctx.use_pre_softmax, ctx.scaling_factor, ctx.score_function, + ctx.routing_map_format, ) # Restore the shape grad_logits = grad_logits.view(tensor_shape) - return grad_logits, None, None, None, None, None, None, None + return grad_logits, None, None, None, None, None, None, None, None def fused_topk_with_score_function( @@ -103,6 +161,7 @@ def fused_topk_with_score_function( scaling_factor: Optional[float], score_function: str, expert_bias: Optional[torch.Tensor], + routing_map_format: Union[str, RoutingMapFormat, int] = RoutingMapFormat.BYTEMAP, ): """ Fused topk with score function router. @@ -121,14 +180,27 @@ def fused_topk_with_score_function( currently support "softmax", "sigmoid" and "sqrtsoftplus". 
expert_bias : torch.Tensor, optional could be used with the sigmoid/sqrtsoftplus score functions. + routing_map_format : Union[str, RoutingMapFormat, int], optional + Output layout for routing_map. ``"bytemap"`` / ``RoutingMapFormat.BYTEMAP`` + (default) returns a bool[T, E] tensor; ``"bitmap_u8"`` / + ``RoutingMapFormat.BITMAP_U8`` returns a uint8[T, ceil(E/8)] tensor with + bit ``(e % 8)`` of byte ``(e / 8)`` set when token ``t`` routes to expert + ``e`` (LSB-first / little-endian packing along the expert axis). Returns ------- probs : torch.Tensor in the same dtype as the "logits". - routing_map : torch.Tensor in bool. + Same shape as ``logits``. + routing_map : torch.Tensor + Same leading dims as ``logits``; trailing dim and dtype depend on + routing_map_format: + - BYTEMAP: bool[*logits.shape[:-1], num_experts] + - BITMAP_U8: uint8[*logits.shape[:-1], ceil(num_experts/8)] + LSB-first bit-packed. """ if logits.dtype == torch.float64: raise ValueError("Current TE does not support float64 router type.") + routing_map_format = _validate_routing_map_format(routing_map_format) return FusedTopkScoreFunction.apply( logits, topk, @@ -138,6 +210,7 @@ def fused_topk_with_score_function( scaling_factor, score_function, expert_bias, + routing_map_format, ) @@ -152,6 +225,7 @@ def forward( logits: torch.Tensor, topk: int, score_function: str, + routing_map_format: "RoutingMapFormat", ): # pylint: disable=missing-function-docstring # Save the shape of the logits @@ -164,6 +238,7 @@ def forward( logits=logits, topk=topk, score_function=score_function, + routing_map_format=routing_map_format, ) ctx.save_for_backward(intermediate_output) ctx.topk = topk @@ -171,6 +246,11 @@ def forward( ctx.num_tokens = num_tokens ctx.num_experts = num_experts ctx.logits_dtype = logits.dtype + # Restore the leading dims of the input on both outputs. The trailing + # dim of routing_map depends on the format: num_experts for BYTEMAP, + # ceil(num_experts/8) for BITMAP_U8. 
+ scores = scores.view(tensor_shape) + routing_map = routing_map.view(*tensor_shape[:-1], routing_map.shape[-1]) return routing_map, scores @staticmethod @@ -195,13 +275,14 @@ def backward(ctx, _, grad_scores): ) # Restore the shape grad_logits = grad_logits.view(tensor_shape) - return grad_logits, None, None + return grad_logits, None, None, None def fused_compute_score_for_moe_aux_loss( logits: torch.Tensor, topk: int, score_function: str, + routing_map_format: Union[str, RoutingMapFormat, int] = RoutingMapFormat.BYTEMAP, ): """ Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function. @@ -211,13 +292,20 @@ def fused_compute_score_for_moe_aux_loss( topk : int score_function : str currently support "softmax", "sigmoid" and "sqrtsoftplus". + routing_map_format : Union[str, RoutingMapFormat, int], optional + Output layout for routing_map; see :func:`fused_topk_with_score_function`. Returns ------- - routing_map : torch.Tensor in bool + routing_map : torch.Tensor + Same leading dims as ``logits``; trailing dim and dtype depend on + routing_map_format (bool[..., num_experts] for BYTEMAP, + uint8[..., ceil(num_experts/8)] for BITMAP_U8). scores : torch.Tensor in fp32 + Same shape as ``logits``. """ - return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function) + routing_map_format = _validate_routing_map_format(routing_map_format) + return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function, routing_map_format) class FusedAuxLoss(torch.autograd.Function):