Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 71 additions & 18 deletions src/maxtext/kernels/megablox/pallas_mosaic_tpu_v2_gmm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class GmmConfigs:
rhs_cfgs: InputConfigs
out_dtype: jnp.dtype
acc_dtype: jnp.dtype
has_partial_sum: bool
zero_init: bool
fuse_act: str | None

Expand Down Expand Up @@ -270,6 +271,16 @@ def out_index_map(self, n_id: jax.Array, gm_id: jax.Array, _: jax.Array):

return (pl.ds(row_start, row_size), 0, n_id)

def ps_index_map(self, n_id: jax.Array, gm_id: jax.Array, _: jax.Array):
  """Returns the block index of the partial-sum tile for one grid step.

  The row range is derived from the m-offsets stored in the metadata for
  `gm_id` and `gm_id + 1`, expressed in units of `size_lhs_sublane` rows:
  the start offset is rounded down and the end offset is rounded up, so the
  returned slice covers every sublane row touched by [m_start, m_end).

  Args:
    n_id: Grid index along the n dimension; passed through as the last
      block index.
    gm_id: Grid index used to look up the m-offset range in the metadata.
    _: Third grid index; unused here.

  Returns:
    A tuple `(pl.ds(row_start, row_size), 0, n_id)`.
  """
  sublane = self.cfgs.dims.size_lhs_sublane
  m_offsets = self.metadata_ref.gm_id_to_m_offset

  # Floor the start and ceil the end so partially covered sublane rows at
  # either edge are still included in the slice.
  first_row = m_offsets[gm_id] // sublane
  past_last_row = pl.cdiv(m_offsets[gm_id + 1], sublane)

  return (pl.ds(first_row, past_last_row - first_row), 0, n_id)


def generate_block_specs(
metadata_ref: MetadataRef, cfgs: GmmConfigs
Expand All @@ -294,7 +305,7 @@ def generate_block_specs(
index_map.rhs_weight_index_map,
pipeline_mode=pl.Buffered(buffer_count=3),
)
rhs_scale_block_spec = rhs_bias_block_spec = None
rhs_scale_block_spec = rhs_bias_block_spec = ps_block_spec = None
if cfgs.rhs_cfgs.has_bias:
rhs_bias_block_spec = pl.BlockSpec(
(None, 1, cfgs.tiles.tile_n),
Expand All @@ -306,6 +317,12 @@ def generate_block_specs(
index_map.rhs_scale_index_map,
)

if cfgs.has_partial_sum:
ps_block_spec = pl.BlockSpec(
(bounded_slice_gm, cfgs.dims.size_lhs_sublane, cfgs.tiles.tile_n),
index_map.ps_index_map,
)

rhs_block_spec = WeightsRef(
weight=rhs_weight_spec,
scale=rhs_scale_block_spec,
Expand All @@ -317,7 +334,7 @@ def generate_block_specs(
index_map.out_index_map,
)

return (lhs_block_spec, rhs_block_spec), out_block_spec
return (lhs_block_spec, rhs_block_spec, ps_block_spec), out_block_spec


# Define kernels.
Expand All @@ -328,6 +345,9 @@ def inner_kernel(
tiled_lhs_ref: jax.Array,
# [tile_m // size_lhs_sublane, size_lhs_sublane, tile_k]
tiled_rhs_ref: RhsRef, # [tile_k, tile_n]
# Partial Sum
tiled_ps_ref: jax.Array | None,
# [tile_m // size_lhs_sublane, size_lhs_sublane, tile_n]
# Out
tiled_out_ref: jax.Array,
# [tile_m // size_lhs_sublane, size_lhs_sublane, tile_n]
Expand Down Expand Up @@ -489,6 +509,9 @@ def _matmul(is_first_k_step: bool, is_last_k_step: bool):
if cfgs.rhs_cfgs.has_bias:
tiled_rhs_bias = tiled_rhs_ref.get_bias()
acc += tiled_rhs_bias.astype(acc.dtype)
if cfgs.has_partial_sum:
ps_tile = tiled_ps_ref[...].reshape(acc.shape)
acc += ps_tile.astype(acc.dtype)

acc = apply_act_fn(acc, cfgs.fuse_act)

Expand Down Expand Up @@ -746,6 +769,7 @@ def kernel_main(
# In
lhs_ref: jax.Array, # [size_m, size_k]
rhs_ref: WeightsRef, # [size_group, size_k, size_n]
partial_sum_ref: jax.Array, # [size_m, size_n]
# Out
out_ref: jax.Array, # [size_m, size_n]
# Scratch memory
Expand Down Expand Up @@ -812,7 +836,7 @@ def kernel_main(
dims=cfgs.dims,
)

(lhs_spec, rhs_spec), out_spec = generate_block_specs(metadata_ref, cfgs)
(lhs_spec, rhs_spec, ps_spec), out_spec = generate_block_specs(metadata_ref, cfgs)

if cfgs.fuse_act is not None:
rhs_up_ref = jax.tree.map(lambda x: x.at[..., cfgs.out_size_n :], rhs_ref)
Expand All @@ -827,16 +851,19 @@ def kernel_main(
pipeline_fn = pltpu.emit_pipeline(
functools.partial(inner_kernel, cfgs=cfgs),
grid=(num_n, num_gm, num_k),
in_specs=(lhs_spec, rhs_spec),
in_specs=(lhs_spec, rhs_spec, ps_spec),
out_specs=out_spec,
)

# Bounded slice requires second last dim to be aligned to the sublane size.
# rhs_ref uses static tiling thus reshape is not needed.
lhs_in = lhs_ref.reshape(-1, cfgs.dims.size_lhs_sublane, lhs_ref.shape[-1])
ps_in = None
if cfgs.has_partial_sum:
ps_in = partial_sum_ref.reshape(-1, cfgs.dims.size_lhs_sublane, partial_sum_ref.shape[-1])
out_in = out_ref.reshape(-1, cfgs.dims.size_lhs_sublane, out_ref.shape[-1])
scratches = [partial_out_ref, acc_ref, metadata_ref]
pipeline_fn(lhs_in, rhs_ref, out_in, scratches=scratches)
pipeline_fn(lhs_in, rhs_ref, ps_in, out_in, scratches=scratches)

if cfgs.zero_init:
zero_out_end(out_ref, semaphore_ref, zero_size, dims=cfgs.dims)
Expand All @@ -848,6 +875,7 @@ def calculate_tiling(
rhs_cfgs: InputConfigs,
vmem_limit_bytes: int,
fuse_act: str | None = None,
has_partial_sum: bool = False,
) -> TileSizes:
"""Calculate optimal tile sizes for GMM kernel."""

Expand Down Expand Up @@ -914,11 +942,13 @@ def _gmm_vmem_estimate(tn: int, tk: int) -> int:
acc_dtype_bytes = 2 if lhs_cfgs.quant_dtype is not None else 4
acc_vmem = tile_m * acc_cols * acc_dtype_bytes

# 4. Output tile (double-buffered)
# 4. Output tile (double-buffered) and partial sum buffer
out_dtype_bytes = jax.dtypes.itemsize_bits(lhs_cfgs.dtype) // 8
out_vmem = 2 * tile_m * tn * out_dtype_bytes
ps_vmem = 2 * tile_m * tn * out_dtype_bytes if has_partial_sum else 0
partial_out_vmem = dims.size_lhs_sublane * tn * out_dtype_bytes

return lhs_vmem + rhs_vmem + acc_vmem + out_vmem
return lhs_vmem + rhs_vmem + acc_vmem + out_vmem + ps_vmem + partial_out_vmem

# Multiple k tiles will introduce accumulation overhead. Thus, we first try
# to fit the tensors into vmem by only adjusting tile_n.
Expand Down Expand Up @@ -952,6 +982,7 @@ def validate_inputs(
rhs: jax.Array,
rhs_scale: jax.Array | None,
rhs_bias: jax.Array | None,
partial_sum: jax.Array | None,
group_sizes: jax.Array,
group_offset: jax.Array,
fuse_act: str | None = None,
Expand All @@ -967,6 +998,8 @@ def validate_inputs(
assert rhs.shape == (size_group, size_k, size_n)
if rhs_bias is not None:
assert rhs_bias.shape == (size_group, 1, size_n)
if partial_sum is not None:
assert partial_sum.shape == (size_m, size_n)
if rhs_scale is not None:
num_quant_blocks = rhs_scale.shape[1]
assert rhs_scale.shape == (size_group, num_quant_blocks, 1, size_n)
Expand Down Expand Up @@ -1043,6 +1076,7 @@ def make_gmm_configs(
rhs: jax.Array,
rhs_scale: jax.Array | None,
rhs_bias: jax.Array | None,
partial_sum: jax.Array | None,
group_sizes: jax.Array,
group_offset: jax.Array,
*,
Expand All @@ -1056,7 +1090,7 @@ def make_gmm_configs(
):
"""Fills the GMM config for the GMM kernel."""

dims = validate_inputs(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset, fuse_act)
dims = validate_inputs(lhs, rhs, rhs_scale, rhs_bias, partial_sum, group_sizes, group_offset, fuse_act)

if rhs_scale is not None:
has_scale = True
Expand Down Expand Up @@ -1118,7 +1152,7 @@ def make_gmm_configs(
if isinstance(tile_info, TileSizes):
tiles = tile_info
else:
tiles = tile_info(dims, lhs_cfgs, rhs_cfgs, vmem_limit_bytes, fuse_act)
tiles = tile_info(dims, lhs_cfgs, rhs_cfgs, vmem_limit_bytes, fuse_act, partial_sum is not None)

return GmmConfigs(
dims=dims,
Expand All @@ -1127,6 +1161,7 @@ def make_gmm_configs(
rhs_cfgs=rhs_cfgs,
out_dtype=jnp.dtype(out_dtype),
acc_dtype=jnp.dtype(acc_dtype),
has_partial_sum=partial_sum is not None,
zero_init=zero_initialize,
fuse_act=fuse_act,
)
Expand Down Expand Up @@ -1161,6 +1196,7 @@ def gmm_v2(
group_sizes: jax.Array, # int32[size_lhs_group]
rhs_scale: jax.Array | None = None, # [size_group, num_blocks, 1, out_size]
rhs_bias: jax.Array | None = None, # [size_group, 1, out_size]
partial_sum: jax.Array | None = None, # [size_m, size_n]
group_offset: jax.Array | None = None, # int32[1]
*,
tile_info: TileSizes | TileFn = calculate_tiling,
Expand All @@ -1184,6 +1220,7 @@ def gmm_v2(
group_sizes: The group sizes of lhs rows of shape [size_lhs_group,].
rhs_scale: The rhs scale of shape [size_group, num_blocks, 1, out_size].
rhs_bias: The rhs bias of shape [size_group, 1, out_size].
partial_sum: Optional. Per-token partial sums of shape [size_m, size_n].
group_offset: Optional. The group offset of shape [1,].
tile_info: The tile sizes or tile function to use.
vmem_limit_bytes: Optional vmem limit in bytes.
Expand Down Expand Up @@ -1214,6 +1251,7 @@ def gmm_v2(
rhs,
rhs_scale,
rhs_bias,
partial_sum,
group_sizes,
group_offset,
tile_info=tile_info,
Expand Down Expand Up @@ -1280,20 +1318,35 @@ def gmm_v2(
aligned_n = align_to(cfgs.out_size_n, num_lanes)
out_init = jax.ShapeDtypeStruct((dims.size_m, aligned_n), cfgs.out_dtype)
rhs_weights = WeightsRef(weight=rhs, scale=rhs_scale, bias=rhs_bias)
in_specs = [
pl.BlockSpec(memory_space=pltpu.HBM),
WeightsRef(
weight=pl.BlockSpec(memory_space=pltpu.HBM),
scale=rhs_scale_spec,
bias=rhs_bias_spec,
),
]

partial_sum_spec = None
if partial_sum is not None:
in_specs.append(pl.BlockSpec(memory_space=pltpu.HBM))
partial_sum_spec = pl.BlockSpec(memory_space=pltpu.HBM)
in_specs = [
pl.BlockSpec(memory_space=pltpu.HBM), # lhs
WeightsRef(
weight=pl.BlockSpec(memory_space=pltpu.HBM),
scale=rhs_scale_spec,
bias=rhs_bias_spec,
), # rhs_weights
partial_sum_spec, # partial_sum
]

return pl.pallas_call(
functools.partial(kernel_main, cfgs=cfgs),
out_shape=out_init,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
pl.BlockSpec(memory_space=pltpu.HBM),
WeightsRef(
weight=pl.BlockSpec(memory_space=pltpu.HBM),
scale=rhs_scale_spec,
bias=rhs_bias_spec,
),
],
in_specs=in_specs,
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
scratch_shapes=scratch_shapes,
),
Expand All @@ -1304,4 +1357,4 @@ def gmm_v2(
name=get_scope_name(cfgs),
cost_estimate=get_cost_estimate(cfgs),
metadata=get_metadata(cfgs),
)(group_sizes, group_offset, lhs, rhs_weights)[:, : cfgs.out_size_n]
)(group_sizes, group_offset, lhs, rhs_weights, partial_sum)[:, : cfgs.out_size_n]
Loading
Loading