From 74d325d98faad3577473ac8c755ce68101619d31 Mon Sep 17 00:00:00 2001 From: Shuwen-Fang Date: Mon, 22 Jun 2026 21:49:34 +0000 Subject: [PATCH] Update tgmm --- .../pallas_mosaic_tpu_v2_tgmm_kernel.py | 45 ++++++++++++++++--- .../unit/pallas_mosaic_tpu_v2_kernel_test.py | 18 ++++++-- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/src/maxtext/kernels/megablox/pallas_mosaic_tpu_v2_tgmm_kernel.py b/src/maxtext/kernels/megablox/pallas_mosaic_tpu_v2_tgmm_kernel.py index 01b17bd247..c174076000 100644 --- a/src/maxtext/kernels/megablox/pallas_mosaic_tpu_v2_tgmm_kernel.py +++ b/src/maxtext/kernels/megablox/pallas_mosaic_tpu_v2_tgmm_kernel.py @@ -54,6 +54,7 @@ class OperandRef: jnp.dtype, jnp.dtype, int, + bool, ], gmm_v2.TileSizes, ] @@ -89,6 +90,7 @@ def calculate_tgmm_tiling( out_dtype: jnp.dtype, acc_dtype: jnp.dtype, target_zero_ref_bytes: int, + has_partial_sum: bool = False, ) -> gmm_v2.TileSizes: """Calculate optimal tile sizes for TGMM kernel.""" # In tgmm, we calculate lhs.T @ dout which doesn't require quantization. @@ -121,8 +123,10 @@ def within_vmem_limit(tile_m, tile_k, tile_n): # XLU's transpose. in order to reduce redundant XLU computation, instead # of performing XLU's transpose every time lhs is pushed into XLU, it # caches the transposed value into VMEM. this increases VMEM requirement. + ps_bytes = tile_k * tile_n * num_buffers * out_bytes if has_partial_sum else 0 budget = ( tile_k * tile_n * (acc_bytes + num_buffers * out_bytes) + + ps_bytes + (num_buffers + 1) * (tile_m * tile_k * lhs_bytes) + num_buffers * (tile_m * tile_n * rhs_bytes) # Reserve VMEM for zero_ref. Use the upper bound target_zero_ref_bytes @@ -179,6 +183,7 @@ def make_tgmm_configs( lhs: jax.Array, # [m, k] rhs: jax.Array, # [m, n] rhs_scale: jax.Array, # [1, 1, n] (per-N scale) + partial_sum: jax.Array | None, group_sizes: jax.Array, num_actual_groups: int, *, @@ -256,6 +261,7 @@ def make_tgmm_configs( out_dtype, acc_dtype, target_zero_ref_bytes, + partial_sum is not None, ) return gmm_v2.GmmConfigs( @@ -263,7 +269,7 @@ def make_tgmm_configs( tiles=tiles, lhs_cfgs=lhs_cfgs, rhs_cfgs=rhs_cfgs, - has_partial_sum=False, # This should always be False until partial sum support is added in bwd pass. + has_partial_sum=(partial_sum is not None), out_dtype=jnp.dtype(out_dtype), acc_dtype=jnp.dtype(acc_dtype), # GMM's 'zero_init' zeros unvisited m-rows via DMA, which doesn't apply to @@ -280,6 +286,7 @@ def tgmm_inner_kernel( tiled_rhs_ref: OperandRef, # .value: [tile_m // size_lhs_sublane, size_lhs_sublane, tile_n] # .scale: [1, 1, tile_n] or None + tiled_ps_ref: jax.Array | None, tiled_out_ref: jax.Array, acc_ref: jax.Array, metadata_ref: gmm_v2.MetadataRef, @@ -342,6 +349,8 @@ def _matmul(is_new_group: bool, is_group_changing: bool): if cfgs.rhs_cfgs.has_scale: scale_slice = tiled_rhs_scale_ref[0] acc *= scale_slice + if cfgs.has_partial_sum: + acc += tiled_ps_ref[...].astype(acc.dtype) tiled_out_ref[...] = acc.astype(tiled_out_ref.dtype) else: acc_ref[...] = acc @@ -430,7 +439,7 @@ def out_index_map(self, n_id: jax.Array, k_id: jax.Array, gm_id: jax.Array): def generate_tgmm_block_specs( metadata_ref: gmm_v2.MetadataRef, cfgs: gmm_v2.GmmConfigs -) -> Tuple[Tuple[pl.BlockSpec, OperandRef], pl.BlockSpec]: +) -> Tuple[Tuple[pl.BlockSpec, ...], pl.BlockSpec]: """Generates block specs for the given lhs, rhs, and out refs.""" index_map = TgmmIndexMaps(metadata_ref, cfgs) # NB: in tgmm, LHS is reshaped from (M, K) to (-1, size_lhs_sublane, K) so @@ -457,8 +466,14 @@ def generate_tgmm_block_specs( (None, cfgs.tiles.tile_k, cfgs.tiles.tile_n), index_map.out_index_map, ) - - return (lhs_block_spec, rhs_spec), out_block_spec + ps_block_spec = None + if cfgs.has_partial_sum: + ps_block_spec = pl.BlockSpec( + (None, cfgs.tiles.tile_k, cfgs.tiles.tile_n), + index_map.out_index_map, + ) + in_specs = (lhs_block_spec, rhs_spec, ps_block_spec) + return in_specs, out_block_spec def zero_out_start( @@ -526,6 +541,7 @@ def tgmm_kernel_main( group_offset_ref, # int32[1] lhs_ref, # [m, k] rhs_ref, # OperandRef: .value [m, n], .scale [1, 1, n] or None + partial_sum_ref, # [num_actual_groups, k, n] or None out_ref, # [num_actual_groups, k, n] # Scratch memory acc_ref: jax.Array, # [tile_k, tile_n] @@ -578,9 +594,12 @@ def tgmm_kernel_main( rhs_value = rhs_ref.value rhs_in = rhs_value.reshape(-1, cfgs.dims.size_lhs_sublane, rhs_value.shape[-1]) rhs_operand = OperandRef(value=rhs_in, scale=rhs_ref.scale) + ps_in = None + if cfgs.has_partial_sum: + ps_in = partial_sum_ref scratches = [acc_ref, metadata_ref] - pipeline_fn(lhs_in, rhs_operand, out_ref, scratches=scratches) + pipeline_fn(lhs_in, rhs_operand, ps_in, out_ref, scratches=scratches) zero_out_end( num_groups_to_zero, out_ref, @@ -632,6 +651,7 @@ def tgmm_v2( group_sizes: jax.Array, num_actual_groups: int, rhs_scale: jax.Array | None = None, # [1, 1, size_n] (per-N scale) + partial_sum: jax.Array | None = None, group_offset: jax.Array | None = None, *, tile_info: gmm_v2.TileSizes | TileTgmmFn = calculate_tgmm_tiling, @@ -683,6 +703,7 @@ def tgmm_v2( lhs, rhs, rhs_scale, + partial_sum, group_sizes, num_actual_groups, tile_info=tile_info, @@ -731,14 +752,18 @@ def tgmm_v2( rhs_scale = jnp.pad(rhs_scale, ((0, 0), (0, 0), (0, pad_n))) rhs = OperandRef(value=rhs, scale=rhs_scale) hbm_spec = pl.BlockSpec(memory_space=pltpu.HBM) + partial_sum_spec = None + if partial_sum is not None: + partial_sum_spec = hbm_spec in_specs = [ hbm_spec, # lhs # the tree.map build a # OperandRef(value=hbm_spec, scale=None if scale is None else hbm_spec. jax.tree.map(lambda _: hbm_spec, rhs), # rhs + partial_sum_spec, ] - return pl.pallas_call( + raw_out = pl.pallas_call( functools.partial(tgmm_kernel_main, cfgs=cfgs), out_shape=out_init, grid_spec=pltpu.PrefetchScalarGridSpec( @@ -756,4 +781,10 @@ def tgmm_v2( # the metadata here is for profiling, debugging, and cost modeling. # It does not affect the kernel's computation. metadata=gmm_v2.get_metadata(cfgs), - )(group_sizes, group_offset, lhs, rhs)[:, : dims.size_k, : dims.size_n] + )(group_sizes, group_offset, lhs, rhs, partial_sum)[:, : dims.size_k, : dims.size_n] + + if partial_sum is not None: + local_group_sizes = lax.dynamic_slice(group_sizes, (group_offset[0],), (num_actual_groups,)) + empty_mask = (local_group_sizes == 0).reshape(num_actual_groups, 1, 1) + return jnp.where(empty_mask, partial_sum, raw_out) + return raw_out diff --git a/tests/unit/pallas_mosaic_tpu_v2_kernel_test.py b/tests/unit/pallas_mosaic_tpu_v2_kernel_test.py index aeeec49533..15f1018b94 100644 --- a/tests/unit/pallas_mosaic_tpu_v2_kernel_test.py +++ b/tests/unit/pallas_mosaic_tpu_v2_kernel_test.py @@ -161,6 +161,7 @@ def reference_tgmm( # group_offset is obtained from # jnp.arange(0, num_experts, num_experts_per_shard) group_offset=None, + partial_sum=None, ): # [num_groups, k, n] """Computes reference transposed grouped matrix multiplication.""" # Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :] @@ -178,7 +179,10 @@ def reference_tgmm( group = global_group - group_offset[0] end = start + group_size if 0 <= group < num_actual_groups: - out.append(lhs[:, start:end] @ rhs[start:end, :]) + res = lhs[:, start:end].astype(jnp.float32) @ rhs[start:end, :].astype(jnp.float32) + if partial_sum is not None: + res = res + partial_sum[group].astype(jnp.float32) + out.append(res.astype(lhs.dtype)) start = end return jnp.stack(out) @@ -257,12 +261,13 @@ def test_gmm_basic(self, batch_size, in_size, out_size, num_groups, has_bias, ha in_size=[512, 1024], out_size=[512, 1024], num_groups=[5, 16, 32], + has_partial_sum=[True, False], group_offset=[0, 2, 3], ) - def test_tgmm_basic(self, batch_size, in_size, out_size, num_groups, group_offset): + def test_tgmm_basic(self, batch_size, in_size, out_size, num_groups, has_partial_sum, group_offset): num_local_groups = num_groups - group_offset key = jax.random.key(0) - key1, key2 = jax.random.split(key, 2) + key1, key2, key3 = jax.random.split(key, 3) lhs = jax.random.normal(key1, (batch_size, in_size), dtype=jnp.bfloat16) # [m, k] grad = jax.random.normal(key2, (batch_size, out_size), dtype=jnp.bfloat16) # [m, n] group_sizes = get_group_sizes(batch_size, num_groups) @@ -270,13 +275,18 @@ def test_tgmm_basic(self, batch_size, in_size, out_size, num_groups, group_offse # group_sizes=Array([14, 14, ..., 7]). group_offset = jnp.array(group_offset, dtype=jnp.int32) + ps = None + if has_partial_sum: + ps = jax.random.normal(key3, (num_local_groups, in_size, out_size), dtype=jnp.bfloat16) + lhs_t = lhs.swapaxes(0, 1) # [k, m] - expected = reference_tgmm(lhs_t, grad, group_sizes, num_local_groups, group_offset=group_offset) + expected = reference_tgmm(lhs_t, grad, group_sizes, num_local_groups, group_offset=group_offset, partial_sum=ps) actual = tgmm_backend.tgmm_v2( lhs, grad, group_sizes, num_local_groups, + partial_sum=ps, group_offset=group_offset, preferred_element_type=jnp.bfloat16, )