From de9589c6e30b34119bef18dad8619f846328725a Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 29 Oct 2025 00:33:57 -0700 Subject: [PATCH 01/14] Adding a benchmark for gemm with input/output bf16 with fp32 accum. --- configs/gemm_bf16_simple.yaml | 8 ++++ src/benchmark_gemm.py | 76 ++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 configs/gemm_bf16_simple.yaml diff --git a/configs/gemm_bf16_simple.yaml b/configs/gemm_bf16_simple.yaml new file mode 100644 index 0000000..cf8cdbd --- /dev/null +++ b/configs/gemm_bf16_simple.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "gemm_bf16_simple" + trace_dir: "../microbenchmarks/gemm_bf16_simple" + csv_path: "../microbenchmarks/gemm_bf16_simple" + xlml_metrics_dir: "../microbenchmarks/gemm_bf16_simple" + num_runs: 1000 + benchmark_sweep_params: + - {m: {start: 1, end: 65536, multiplier: 2}, k: {start: 1, end: 65536, multiplier: 2}, n: {start: 1, end: 65536, multiplier: 2}} diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 14b64a5..fa6c887 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -240,6 +240,80 @@ def gemm_simple_calculate_metrics( total_flops = total_flops // jax.device_count() return unified_flops_metrics(m, n, k, time_ms_list, total_flops, total_flops_all_devices, PEAK_FLOPS_PER_DEVICE*2) +def gemm_bf16_simple( + m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None +) -> Dict[str, Any]: + """Benchmarks the OUT:BF16 = IN0:BF16 x IN1:BF16. Accumulation is FP32.""" + + def f(x, y): + with jax.named_scope(MARKER): + # Keep accumulation in F)32 for precision + acc = jax.numpy.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32) + # Output is BF16 + return acc.astype(jnp.bfloat16) + + mesh = create_mesh() + rhs_sharding = NamedSharding(mesh, P(None, None)) + if WITH_SHARDING: + lhs_sharding = NamedSharding(mesh, P("i", None)) + out_sharding = P("i", None) + else: + lhs_sharding = NamedSharding(mesh, P(None, None)) + out_sharding = P(None, None) + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=(lhs_sharding.spec, rhs_sharding.spec), + out_specs=out_sharding, + check_rep=False, + ) + ) + + lhs_shape = (m, k) + rhs_shape = (k, n) + lhs_dtype = jnp.bfloat16 + rhs_dtype = jnp.bfloat16 + + key = jax.random.key(SEED) + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key # Use and update the outer 'key' + key, key_lhs, key_rhs = jax.random.split(key, 3) + + # Create random data on host + lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype) + rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype) + + # Put on device (HBM) + lhs_device = jax.device_put(lhs_host, lhs_sharding) + rhs_device = jax.device_put(rhs_host, rhs_sharding) + + return (lhs_device, rhs_device) + + # Run the benchmark + time_ms_list = iteration_timeit( + jit_sharded_f, + data_generator, + matrix_dim=f"{m}x{n}x{k}", + tries=num_runs, + task="gemm_bf16_simple", + trace_dir=trace_dir, + ) + return {"time_ms_list": time_ms_list} + +def gemm_bf16_simple_calculate_metrics( + m: int, k: int, n: int, time_ms_list: list[float] +) -> Dict[str, Any]: + # Calculate FLOPs + total_flops = (2 * m * k * n) - (m * n) # Total floating-point operations + total_flops_all_devices = total_flops + if WITH_SHARDING: + total_flops = total_flops // jax.device_count() + return unified_flops_metrics(m, n, k, time_ms_list, total_flops, total_flops_all_devices, PEAK_FLOPS_PER_DEVICE*2) + def gemm( m: int, k: 
int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: @@ -823,4 +897,4 @@ def add_calculate_metrics( total_bytes_all_device = total_bytes if WITH_SHARDING: total_bytes = total_bytes // jax.device_count() - return unified_bytes_metrics(m, n, time_ms_list, total_bytes, total_bytes_all_device) \ No newline at end of file + return unified_bytes_metrics(m, n, time_ms_list, total_bytes, total_bytes_all_device) From 26225de18a01e075b3e47bf809b079476b9d6805 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 30 Oct 2025 11:20:12 -0700 Subject: [PATCH 02/14] Fixing the total_flops calculations formula and adding the rescale yamls --- configs/gemm_rescale_inference.yaml | 8 ++++++++ configs/gemm_simple_bf16_inference.yaml | 8 ++++++++ configs/gemm_simple_inference.yaml | 8 ++++++++ src/benchmark_gemm.py | 4 ++-- src/run_benchmark.py | 1 + 5 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 configs/gemm_rescale_inference.yaml create mode 100644 configs/gemm_simple_bf16_inference.yaml create mode 100644 configs/gemm_simple_inference.yaml diff --git a/configs/gemm_rescale_inference.yaml b/configs/gemm_rescale_inference.yaml new file mode 100644 index 0000000..9e2a31d --- /dev/null +++ b/configs/gemm_rescale_inference.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "gemm" + trace_dir: "../microbenchmarks/gemm_rescale_inference" + csv_path: "../microbenchmarks/gemm_rescale_inference" + xlml_metrics_dir: "../microbenchmarks/gemm_rescale_inference" + num_runs: 1000 + benchmark_sweep_params: + - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}} diff --git a/configs/gemm_simple_bf16_inference.yaml b/configs/gemm_simple_bf16_inference.yaml new file mode 100644 index 0000000..18e6e81 --- /dev/null +++ b/configs/gemm_simple_bf16_inference.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "gemm_bf16_simple" + trace_dir: "../microbenchmarks/gemm_simple_bf16_inference" + csv_path: "../microbenchmarks/gemm_simple_bf16_inference" + xlml_metrics_dir: "../microbenchmarks/gemm_simple_bf16_inference" + num_runs: 1000 + benchmark_sweep_params: + - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}} diff --git a/configs/gemm_simple_inference.yaml b/configs/gemm_simple_inference.yaml new file mode 100644 index 0000000..3af7b1f --- /dev/null +++ b/configs/gemm_simple_inference.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "gemm_simple" + trace_dir: "../microbenchmarks/gemm_simple_inference" + csv_path: "../microbenchmarks/gemm_simple_inference" + xlml_metrics_dir: "../microbenchmarks/gemm_simple_inference" + num_runs: 1000 + benchmark_sweep_params: + - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}} diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index fa6c887..20b9664 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -234,7 +234,7 @@ def gemm_simple_calculate_metrics( m: int, k: int, n: int, time_ms_list: list[float] ) -> Dict[str, Any]: # Calculate FLOPs - total_flops = 2 * m * k * n # Total floating-point operations + total_flops = (2 * k - 1) * m * n # Total floating-point operations total_flops_all_devices = total_flops if WITH_SHARDING: total_flops = total_flops // jax.device_count() @@ -393,7 +393,7 @@ def gemm_calculate_metrics( m: int, k: int, n: int, time_ms_list: list[float] ) -> Dict[str, 
Any]: # Calculate FLOPs - total_flops = 2 * m * k * n # Total floating-point operations + total_flops = (2 * k + 1) * m * n # Total floating-point operations total_flops_all_devices = total_flops if WITH_SHARDING: total_flops = total_flops // jax.device_count() diff --git a/src/run_benchmark.py b/src/run_benchmark.py index fb114d3..b5c294c 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -62,6 +62,7 @@ } GEMM_BENCHMARK_MAP = { "gemm_simple":"benchmark_gemm.gemm_simple", + "gemm_bf16_simple":"benchmark_gemm.gemm_bf16_simple", "gemm": "benchmark_gemm.gemm", "gemm_accum": "benchmark_gemm.gemm_accum", "quantization": "benchmark_gemm.quantization", From 79bd4c0c47efbb8d3cbb580231807353ad5d439e Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 30 Oct 2025 11:21:24 -0700 Subject: [PATCH 03/14] Removing a yaml. --- configs/gemm_bf16_simple.yaml | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 configs/gemm_bf16_simple.yaml diff --git a/configs/gemm_bf16_simple.yaml b/configs/gemm_bf16_simple.yaml deleted file mode 100644 index cf8cdbd..0000000 --- a/configs/gemm_bf16_simple.yaml +++ /dev/null @@ -1,8 +0,0 @@ -benchmarks: -- benchmark_name: "gemm_bf16_simple" - trace_dir: "../microbenchmarks/gemm_bf16_simple" - csv_path: "../microbenchmarks/gemm_bf16_simple" - xlml_metrics_dir: "../microbenchmarks/gemm_bf16_simple" - num_runs: 1000 - benchmark_sweep_params: - - {m: {start: 1, end: 65536, multiplier: 2}, k: {start: 1, end: 65536, multiplier: 2}, n: {start: 1, end: 65536, multiplier: 2}} From 0fef570e128d9e794c05674fdd3df9af5486a9df Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 30 Oct 2025 13:34:37 -0700 Subject: [PATCH 04/14] Making change to gemm_simple to take in_dtype as input. Also added gemm_batched_simple benchmark. 
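[Editor's note] The gemm_batched_simple benchmark referenced in this subject (and landed in patch 05 below) times a plain batched matmul with FP32 accumulation, cast back to the requested output dtype. A minimal, unsharded sketch of that core operation — batched_matmul, the shapes, and the bf16 defaults here are illustrative, not taken from the patch; the real benchmark adds shard_map sharding, device placement, and trace capture around the same einsum:

import jax
import jax.numpy as jnp

def batched_matmul(x, y, out_dtype=jnp.bfloat16):
    # (B, M, K) @ (B, K, N) -> (B, M, N); preferred_element_type asks XLA to accumulate in FP32
    acc = jnp.einsum("bij,bjk->bik", x, y, preferred_element_type=jnp.float32)
    return acc.astype(out_dtype)

b, m, k, n = 4, 512, 512, 512
key_x, key_y = jax.random.split(jax.random.key(0))
x = jax.random.normal(key_x, (b, m, k)).astype(jnp.bfloat16)
y = jax.random.normal(key_y, (b, k, n)).astype(jnp.bfloat16)
out = jax.jit(batched_matmul)(x, y)
print(out.shape, out.dtype)  # (4, 512, 512) bfloat16

In the patch, the same einsum is wrapped in shard_map: the 3-D in/out specs are built from the existing 2-D sharding helpers by prepending None for the batch dimension, and fresh device-placed inputs are generated for every timed iteration.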
--- configs/gemm_simple_bf16_inference.yaml | 8 -------- configs/gemm_simple_inference.yaml | 3 ++- 2 files changed, 2 insertions(+), 9 deletions(-) delete mode 100644 configs/gemm_simple_bf16_inference.yaml diff --git a/configs/gemm_simple_bf16_inference.yaml b/configs/gemm_simple_bf16_inference.yaml deleted file mode 100644 index 18e6e81..0000000 --- a/configs/gemm_simple_bf16_inference.yaml +++ /dev/null @@ -1,8 +0,0 @@ -benchmarks: -- benchmark_name: "gemm_bf16_simple" - trace_dir: "../microbenchmarks/gemm_simple_bf16_inference" - csv_path: "../microbenchmarks/gemm_simple_bf16_inference" - xlml_metrics_dir: "../microbenchmarks/gemm_simple_bf16_inference" - num_runs: 1000 - benchmark_sweep_params: - - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}} diff --git a/configs/gemm_simple_inference.yaml b/configs/gemm_simple_inference.yaml index 3af7b1f..3ec229a 100644 --- a/configs/gemm_simple_inference.yaml +++ b/configs/gemm_simple_inference.yaml @@ -5,4 +5,5 @@ benchmarks: xlml_metrics_dir: "../microbenchmarks/gemm_simple_inference" num_runs: 1000 benchmark_sweep_params: - - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}} + - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}, in_dtype_str: "bf16", out_dtype_str: "bf16"} + - {m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}, in_dtype_str: "fp8", out_dtype_str: "bf16"} From 879e1b841e2f93be152a4ab7761e8d149cf7290b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 30 Oct 2025 13:49:10 -0700 Subject: [PATCH 05/14] Making change to gemm_simple to take in_dtype as input. Also added gemm_batched_simple benchmark. [Part-2] --- src/benchmark_gemm.py | 153 +++++++++++++++++++++++++++++++++--------- src/run_benchmark.py | 2 +- 2 files changed, 122 insertions(+), 33 deletions(-) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 97f6ee8..c52ba97 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -60,6 +60,20 @@ class ShardingStrategy(Enum): SEED = 0 PEAK_FLOPS_PER_DEVICE=1153.5 # TFLOP/s for single core(device) under p_state=7 + +def str_to_dtype(dtype_str: str) -> jnp.dtype: + """Converts a string identifier to a JAX numpy dtype.""" + if dtype_str.lower() == "fp8": + return jnp.float8_e4m3fn + elif dtype_str.lower() == "bf16": + return jnp.bfloat16 + elif dtype_str.lower() == "fp16": + return jnp.float16 + elif dtype_str.lower() == "fp32": + return jnp.float32 + else: + raise ValueError(f"Unsupported dtype string: {dtype_str}") + def get_lhs_named_shading(mesh): match SHARDING_STRATEGY: case ShardingStrategy.NO_SHARDING: @@ -284,14 +298,21 @@ def unified_bytes_metrics( return metadata, metrics def gemm_simple( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None + m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8. 
Accumulation is FP32.""" + # Convert string dtypes to jnp dtypes + lhs_dtype = str_to_dtype(in_dtype_str) + rhs_dtype = str_to_dtype(in_dtype_str) + out_dtype = str_to_dtype(out_dtype_str) + def f(x, y): with jax.named_scope(MARKER): acc = jax.numpy.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32) - return acc.astype(jnp.bfloat16) + return acc.astype(out_dtype) mesh = create_mesh() lhs_sharding = get_lhs_named_shading(mesh) @@ -310,8 +331,6 @@ def f(x, y): lhs_shape = (m, k) rhs_shape = (k, n) - lhs_dtype = jnp.float8_e4m3fn - rhs_dtype = jnp.float8_e4m3fn key = jax.random.key(SEED) @@ -336,50 +355,91 @@ def data_generator(): data_generator, matrix_dim=f"{m}x{n}x{k}", tries=num_runs, - task="gemm_simple", + task=f"gemm_simple_{in_dtype_str}_{out_dtype_str}", trace_dir=trace_dir, ) return {"time_ms_list": time_ms_list} def gemm_simple_calculate_metrics( - m: int, k: int, n: int, time_ms_list: list[float] + m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + time_ms_list: list[float] ) -> Dict[str, Any]: # Calculate FLOPs total_flops = (2 * k - 1) * m * n # Total floating-point operations total_flops, total_flops_all_devices = handle_based_on_sharding(total_flops) - return unified_flops_metrics(m, n, k, time_ms_list, total_flops, total_flops_all_devices, PEAK_FLOPS_PER_DEVICE*2) -def gemm_bf16_simple( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None + # Set peak FLOPS multiplier based on input datatype + if in_dtype_str.lower() == "bf16" or in_dtype_str.lower() == "fp16": + peak_flops_multiplier = 0.5 + elif in_dtype_str.lower() == "fp32": + peak_flops_multiplier = 0.25 + elif in_dtype_str.lower() == "fp8": + peak_flops_multiplier = 1.0 + else + raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier." + + metadata, metrics = unified_flops_metrics( + m, n, k, time_ms_list, + total_flops, total_flops_all_devices, + PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier * 2) + + # Add dtype info to metadata for logging + metadata["in_dtype"] = in_dtype_str + metadata["out_dtype"] = out_dtype_str + + return metadata, metrics + +def gemm_batched_simple( + b: int, m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: - """Benchmarks the OUT:BF16 = IN0:BF16 x IN1:BF16. Accumulation is FP32.""" + """Benchmarks the BATCHED OUT = IN0 x IN1. 
Accumulation is FP32.""" + + # Convert string dtypes to jnp dtypes + lhs_dtype = str_to_dtype(in_dtype_str) + rhs_dtype = str_to_dtype(in_dtype_str) + out_dtype = str_to_dtype(out_dtype_str) def f(x, y): with jax.named_scope(MARKER): - # Keep accumulation in F)32 for precision - acc = jax.numpy.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32) - # Output is BF16 - return acc.astype(jnp.bfloat16) + # Batched matmul: (B, M, K) @ (B, K, N) -> (B, M, N) + acc = jax.numpy.einsum("bij,bjk->bik", x, y, preferred_element_type=jnp.float32) + return acc.astype(out_dtype) mesh = create_mesh() - lhs_sharding = get_lhs_named_shading(mesh) - rhs_sharding = get_rhs_named_shading(mesh) - out_sharding = get_out_sharding() + + # Get the 2D sharding specs from your helper functions + lhs_sharding_2d = get_lhs_named_shading(mesh).spec + rhs_sharding_2d = get_rhs_named_shading(mesh).spec + out_sharding_2d = get_out_sharding() + + # Create new 3D specs by adding 'None' for the batch dimension (dim 0) + # (B, M, K) - sharding from (M, K) + lhs_spec = P(None, *lhs_sharding_2d) + # (B, K, N) - sharding from (K, N) + rhs_spec = P(None, *rhs_sharding_2d) + # (B, M, N) - sharding from (M, N) + out_spec = P(None, *out_sharding_2d) + + # Create the full NamedSharding objects for the data generator + lhs_sharding = NamedSharding(mesh, lhs_spec) + rhs_sharding = NamedSharding(mesh, rhs_spec) jit_sharded_f = jax.jit( shard_map( f, mesh, - in_specs=(lhs_sharding.spec, rhs_sharding.spec), - out_specs=out_sharding, + in_specs=(lhs_spec, rhs_spec), + out_specs=out_spec, check_rep=False, ) ) - lhs_shape = (m, k) - rhs_shape = (k, n) - lhs_dtype = jnp.bfloat16 - rhs_dtype = jnp.bfloat16 + # Add the batch dimension 'b' to the shapes + lhs_shape = (b, m, k) + rhs_shape = (b, k, n) key = jax.random.key(SEED) @@ -388,7 +448,7 @@ def data_generator(): nonlocal key # Use and update the outer 'key' key, key_lhs, key_rhs = jax.random.split(key, 3) - # Create random data on host + # Create random data on host with new 3D shapes lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype) rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype) @@ -402,20 +462,49 @@ def data_generator(): time_ms_list = iteration_timeit( jit_sharded_f, data_generator, - matrix_dim=f"{m}x{n}x{k}", + matrix_dim=f"{b}x{m}x{n}x{k}", tries=num_runs, - task="gemm_bf16_simple", + task=f"gemm_batched_simple_{in_dtype_str}_{out_dtype_str}", trace_dir=trace_dir, ) return {"time_ms_list": time_ms_list} -def gemm_bf16_simple_calculate_metrics( - m: int, k: int, n: int, time_ms_list: list[float] +def gemm_batched_simple_calculate_metrics( + b: int, m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + time_ms_list: list[float] ) -> Dict[str, Any]: - # Calculate FLOPs - total_flops = (2 * k - 1) * m * n) # Total floating-point operations - total_flops, total_flops_all_devices = handle_based_on_sharding(total_flops) - return unified_flops_metrics(m, n, k, time_ms_list, total_flops, total_flops_all_devices, PEAK_FLOPS_PER_DEVICE*2) + + # Calculate FLOPs for the *entire batch* + total_flops_base = b * (2 * m * k * n) + + # Get per-device and all-device FLOPS based on sharding strategy + total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(total_flops_base) + + # Set peak FLOPS multiplier based on input datatype + if in_dtype_str.lower() == "bf16" or in_dtype_str.lower() == "fp16": + peak_flops_multiplier = 0.5 + elif in_dtype_str.lower() == "fp32": + peak_flops_multiplier = 0.25 + elif in_dtype_str.lower() == 
"fp8": + peak_flops_multiplier = 1.0 + else + raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier." + + + metadata, metrics = unified_flops_metrics( + m, n, k, time_ms_list, + total_flops_per_device, + total_flops_all_devices, + PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier * 2 + ) + + # Manually add the new parameters to the metadata for logging + metadata["b"] = b + metadata["in_dtype"] = in_dtype_str + metadata["out_dtype"] = out_dtype_str + + return metadata, metrics def gemm( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None diff --git a/src/run_benchmark.py b/src/run_benchmark.py index 10f8f22..1c4c099 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -62,7 +62,7 @@ } GEMM_BENCHMARK_MAP = { "gemm_simple":"benchmark_gemm.gemm_simple", - "gemm_bf16_simple":"benchmark_gemm.gemm_bf16_simple", + "gemm_batched_simple":"benchmark_gemm.gemm_batched_simple", "gemm": "benchmark_gemm.gemm", "gemm_accum": "benchmark_gemm.gemm_accum", "quantization": "benchmark_gemm.quantization", From 23634e1919e72db5d3f4a35e1df6c002bd3dbfc4 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 30 Oct 2025 14:55:28 -0700 Subject: [PATCH 06/14] Fixing typos --- src/benchmark_gemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index c52ba97..8c3252f 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -376,8 +376,8 @@ def gemm_simple_calculate_metrics( peak_flops_multiplier = 0.25 elif in_dtype_str.lower() == "fp8": peak_flops_multiplier = 1.0 - else - raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier." + else: + raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier.") metadata, metrics = unified_flops_metrics( m, n, k, time_ms_list, @@ -488,8 +488,8 @@ def gemm_batched_simple_calculate_metrics( peak_flops_multiplier = 0.25 elif in_dtype_str.lower() == "fp8": peak_flops_multiplier = 1.0 - else - raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier." + else: + raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier.") metadata, metrics = unified_flops_metrics( From bb92b9ea8043817c71fa64255941e82a6b5e974b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 30 Oct 2025 15:52:42 -0700 Subject: [PATCH 07/14] Adding yaml for gemm_grouped = gemm_batched. 
--- configs/gemm_grouped_inference.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 configs/gemm_grouped_inference.yaml diff --git a/configs/gemm_grouped_inference.yaml b/configs/gemm_grouped_inference.yaml new file mode 100644 index 0000000..f339f08 --- /dev/null +++ b/configs/gemm_grouped_inference.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "gemm_batched_simple" + trace_dir: "../microbenchmarks/gemm_batched_simple_inference" + csv_path: "../microbenchmarks/gemm_batched_simple_inference" + xlml_metrics_dir: "../microbenchmarks/gemm_batched_simple_inference" + num_runs: 1000 + benchmark_sweep_params: + - {b: {start: 4, end: 256, multiplier: 2}, m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}, in_dtype_str: "bf16", out_dtype_str: "bf16"} From 9600a509551ffe6319ce3484a3ebcf88a2cb3e21 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 31 Oct 2025 07:18:27 -0700 Subject: [PATCH 08/14] Adding gemm_batched for grouped gemm scaled matrix multiplication benchmark. --- configs/gemm_grouped_inference.yaml | 3 +- configs/gemm_grouped_rescale_inference.yaml | 8 + src/benchmark_gemm.py | 168 +++++++++++++++++--- src/run_benchmark.py | 1 + 4 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 configs/gemm_grouped_rescale_inference.yaml diff --git a/configs/gemm_grouped_inference.yaml b/configs/gemm_grouped_inference.yaml index f339f08..43e84c3 100644 --- a/configs/gemm_grouped_inference.yaml +++ b/configs/gemm_grouped_inference.yaml @@ -5,4 +5,5 @@ benchmarks: xlml_metrics_dir: "../microbenchmarks/gemm_batched_simple_inference" num_runs: 1000 benchmark_sweep_params: - - {b: {start: 4, end: 256, multiplier: 2}, m: {start: 512, end: 65536, multiplier: 2}, k: {start: 512, end: 65536, multiplier: 2}, n: {start: 512, end: 65536, multiplier: 2}, in_dtype_str: "bf16", out_dtype_str: "bf16"} + - {b: {start: 4, end: 256, multiplier: 2}, m: {start: 256, end: 2048, multiplier: 2}, k: {start: 256, end: 2048, multiplier: 2}, n: {start: 256, end: 2048, multiplier: 2}, in_dtype_str: "bf16", out_dtype_str: "bf16"} + - {b: {start: 4, end: 256, multiplier: 2}, m: {start: 256, end: 2048, multiplier: 2}, k: {start: 256, end: 2048, multiplier: 2}, n: {start: 256, end: 2048, multiplier: 2}, in_dtype_str: "fp8", out_dtype_str: "bf16"} diff --git a/configs/gemm_grouped_rescale_inference.yaml b/configs/gemm_grouped_rescale_inference.yaml new file mode 100644 index 0000000..195e164 --- /dev/null +++ b/configs/gemm_grouped_rescale_inference.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "gemm_batched" + trace_dir: "../microbenchmarks/gemm_batched_rescale_inference" + csv_path: "../microbenchmarks/gemm_batched_rescale_inference" + xlml_metrics_dir: "../microbenchmarks/gemm_batched_rescale_inference" + num_runs: 1000 + benchmark_sweep_params: + - {b: {start: 4, end: 256, multiplier: 2}, m: {start: 56, end: 2048, multiplier: 2}, k: {start: 256, end: 2048, multiplier: 2}, n: {start: 256, end: 2048, multiplier: 2}, in_dtype_str: "fp8", out_dtype_str: "bf16"} diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 9e6e528..7dc777a 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -74,6 +74,26 @@ def str_to_dtype(dtype_str: str) -> jnp.dtype: else: raise ValueError(f"Unsupported dtype string: {dtype_str}") +def get_peak_flops_multiplier(in_dtype_str: str) -> float: + """ + Returns the peak FLOPS multiplier relative to the baseline + (PEAK_FLOPS_PER_DEVICE) based on the input 
data type. + """ + in_dtype_lower = in_dtype_str.lower() + if in_dtype_lower == "fp8": + # FP8 is 2x faster than BF16 + # The baseline PEAK_FLOPS_PER_DEVICE is 1153.5 * 2 = 2307, which is FP8 peak. + # So the multiplier should be 1.0 + return 1.0 + elif in_dtype_lower == "bf16" or in_dtype_lower == "fp16": + # BF16/FP16 is 2x slower than FP8 peak + return 0.5 + elif in_dtype_lower == "fp32": + # FP32 is 4x slower than FP8 peak + return 0.25 + else: + raise RuntimeError(f"{in_dtype_lower} is not supported for setting peak_flops_multiplier.") + def get_lhs_named_shading(mesh): match SHARDING_STRATEGY: case ShardingStrategy.NO_SHARDING: @@ -369,15 +389,8 @@ def gemm_simple_calculate_metrics( total_flops = (2 * k - 1) * m * n # Total floating-point operations total_flops, total_flops_all_devices = handle_based_on_sharding(total_flops) - # Set peak FLOPS multiplier based on input datatype - if in_dtype_str.lower() == "bf16" or in_dtype_str.lower() == "fp16": - peak_flops_multiplier = 0.5 - elif in_dtype_str.lower() == "fp32": - peak_flops_multiplier = 0.25 - elif in_dtype_str.lower() == "fp8": - peak_flops_multiplier = 1.0 - else: - raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier.") + # Get the multiplier by calling the utility function + peak_flops_multiplier = get_peak_flops_multiplier(in_dtype_str) metadata, metrics = unified_flops_metrics( m, n, k, time_ms_list, @@ -481,16 +494,8 @@ def gemm_batched_simple_calculate_metrics( # Get per-device and all-device FLOPS based on sharding strategy total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(total_flops_base) - # Set peak FLOPS multiplier based on input datatype - if in_dtype_str.lower() == "bf16" or in_dtype_str.lower() == "fp16": - peak_flops_multiplier = 0.5 - elif in_dtype_str.lower() == "fp32": - peak_flops_multiplier = 0.25 - elif in_dtype_str.lower() == "fp8": - peak_flops_multiplier = 1.0 - else: - raise RuntimeError(f"{in_dtype_str.lower()} is not supported for setting peak_flops_multiplier.") - + # Get the multiplier by calling the utility function + peak_flops_multiplier = get_peak_flops_multiplier(in_dtype_str) metadata, metrics = unified_flops_metrics( m, n, k, time_ms_list, @@ -583,6 +588,131 @@ def gemm_calculate_metrics( total_flops, total_flops_all_devices = handle_based_on_sharding(total_flops) return unified_flops_metrics(m, n, k, time_ms_list, total_flops, total_flops_all_devices, PEAK_FLOPS_PER_DEVICE) +def gemm_batched( + b: int, m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + num_runs: int = 1, trace_dir: str = None +) -> Dict[str, Any]: + """Batched OUT = matmul(IN0, IN1) * outer_product(SF0 * SF1)""" + + # Convert string dtypes to jnp dtypes + lhs_dtype = str_to_dtype(in_dtype_str) + rhs_dtype = str_to_dtype(in_dtype_str) + out_dtype = str_to_dtype(out_dtype_str) + # Scale factors are typically FP32 + sf_dtype = jnp.float32 + + def f(x, y, scale_m, scale_n): + with jax.named_scope(MARKER): + # Batched matmul: (B, M, K) @ (B, K, N) -> (B, M, N) + acc = jax.numpy.einsum("bij,bjk->bik", x, y, preferred_element_type=jnp.float32) + # Batched scale outer product: (B, M, 1) * (B, 1, N) -> (B, M, N) + scales = scale_m * scale_n + # Batched element-wise scaling + result_fp32 = acc * scales + return result_fp32.astype(out_dtype) + + mesh = create_mesh() + + # --- Adapt 2D sharding specs to 3D by prepending 'None' for batch dim --- + lhs_sharding_2d = get_lhs_named_shading(mesh).spec + rhs_sharding_2d = get_rhs_named_shading(mesh).spec + 
out_sharding_2d = get_out_sharding() + + # (B, M, K) + lhs_spec = P(None, *lhs_sharding_2d) + # (B, K, N) + rhs_spec = P(None, *rhs_sharding_2d) + # (B, M, 1) - Sharded along M, like LHS + sf0_spec = P(None, *get_lhs_named_shading(mesh).spec) + # (B, 1, N) - Sharded along N, like RHS + sf1_spec = P(None, None, get_rhs_named_shading(mesh).spec[1]) + # (B, M, N) + out_spec = P(None, *out_sharding_2d) + + # Create full NamedSharding objects for the data generator + lhs_sharding = NamedSharding(mesh, lhs_spec) + rhs_sharding = NamedSharding(mesh, rhs_spec) + sf0_sharding = NamedSharding(mesh, sf0_spec) + sf1_sharding = NamedSharding(mesh, sf1_spec) + # --- End of sharding logic --- + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=(lhs_spec, rhs_spec, sf0_spec, sf1_spec), + out_specs=out_spec, + check_rep=False, + ) + ) + + # Add the batch dimension 'b' to all shapes + lhs_shape = (b, m, k) + rhs_shape = (b, k, n) + sf0_shape = (b, m, 1) + sf1_shape = (b, 1, n) + + key = jax.random.key(SEED) + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key # Use and update the outer 'key' + key, k1, k2, k3, k4 = jax.random.split(key, 5) + + # Create random data on host with new 3D shapes + lhs_host = jax.random.normal(k1, lhs_shape).astype(lhs_dtype) + rhs_host = jax.random.normal(k2, rhs_shape).astype(rhs_dtype) + sf0_host = jax.random.normal(k3, sf0_shape).astype(sf_dtype) + sf1_host = jax.random.normal(k4, sf1_shape).astype(sf_dtype) + + # Put on device (HBM) + lhs_device = jax.device_put(lhs_host, lhs_sharding) + rhs_device = jax.device_put(rhs_host, rhs_sharding) + sf0_device = jax.device_put(sf0_host, sf0_sharding) + sf1_device = jax.device_put(sf1_host, sf1_sharding) + + return (lhs_device, rhs_device, sf0_device, sf1_device) + + time_ms_list = iteration_timeit( + jit_sharded_f, + data_generator, + matrix_dim=f"{b}x{m}x{n}x{k}", + tries=num_runs, + task=f"gemm_batched_{in_dtype_str}_{out_dtype_str}", + trace_dir=trace_dir, + ) + return {"time_ms_list": time_ms_list} + +def gemm_batched_calculate_metrics( + b: int, m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + time_ms_list: list[float] +) -> Dict[str, Any]: + + # Calculate FLOPs for the *entire batch* + # (2*k+1)*m*n FLOPS per item, multiplied by batch size b + total_flops_base = b * ((2 * k + 1) * m * n) + + # Get per-device and all-device FLOPS based on sharding strategy + total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(total_flops_base) + + # Get the multiplier by calling the utility function + peak_flops_multiplier = get_peak_flops_multiplier(in_dtype_str) + + metadata, metrics = unified_flops_metrics( + m, n, k, time_ms_list, + total_flops_per_device, + total_flops_all_devices, + PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier + ) + + # Manually add the new parameters to the metadata for logging + metadata["b"] = b + metadata["in_dtype"] = in_dtype_str + metadata["out_dtype"] = out_dtype_str + + return metadata, metrics def gemm_accum( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None, diff --git a/src/run_benchmark.py b/src/run_benchmark.py index a2a3115..a371bc9 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -64,6 +64,7 @@ "gemm_simple":"benchmark_gemm.gemm_simple", "gemm_batched_simple":"benchmark_gemm.gemm_batched_simple", "gemm": "benchmark_gemm.gemm", + "gemm_batched":"benchmark_gemm.gemm_batched", "gemm_accum": "benchmark_gemm.gemm_accum", "quantization": "benchmark_gemm.quantization", 
"transpose_quantization": "benchmark_gemm.transpose_quantization", From a2c8ab7b73043a2744f9419942d7ca065494fc86 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 4 Nov 2025 14:38:42 -0800 Subject: [PATCH 09/14] Adding gemm_grouped benchmark using ragged_dot kernel. --- configs/gemm_grouped_inference.yaml | 8 +- src/benchmark_gemm.py | 159 ++++++++++++++++++++++++++++ src/run_benchmark.py | 1 + 3 files changed, 164 insertions(+), 4 deletions(-) diff --git a/configs/gemm_grouped_inference.yaml b/configs/gemm_grouped_inference.yaml index 43e84c3..e6471a7 100644 --- a/configs/gemm_grouped_inference.yaml +++ b/configs/gemm_grouped_inference.yaml @@ -1,8 +1,8 @@ benchmarks: -- benchmark_name: "gemm_batched_simple" - trace_dir: "../microbenchmarks/gemm_batched_simple_inference" - csv_path: "../microbenchmarks/gemm_batched_simple_inference" - xlml_metrics_dir: "../microbenchmarks/gemm_batched_simple_inference" +- benchmark_name: "gemm_grouped_ragged_dot" + trace_dir: "../microbenchmarks/gemm_grouped_ragged_dot_inference" + csv_path: "../microbenchmarks/gemm_grouped_ragged_dot_inference" + xlml_metrics_dir: "../microbenchmarks/gemm_grouped_ragged_dot_inference" num_runs: 1000 benchmark_sweep_params: - {b: {start: 4, end: 256, multiplier: 2}, m: {start: 256, end: 2048, multiplier: 2}, k: {start: 256, end: 2048, multiplier: 2}, n: {start: 256, end: 2048, multiplier: 2}, in_dtype_str: "bf16", out_dtype_str: "bf16"} diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 21b3e90..3f55124 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -512,6 +512,165 @@ def gemm_batched_simple_calculate_metrics( return metadata, metrics +# --- Add this import at the top of your file --- +try: + import tokamax_api +except ImportError: + print("Warning: tokamax_api not found. gemm_ragged_dot will not be available.") + tokamax_api = None +# --- + +def gemm_grouped_ragged_dot( + b: int, m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + num_runs: int = 1, trace_dir: str = None +) -> Dict[str, Any]: + """Benchmarks a Batched GEMM using tokamax_api.ragged_dot. + + This implements (B, M, K) @ (B, K, N) by treating it as a Grouped GEMM: + LHS: (B*M, K) + RHS: (B, K, N) [B is the number of "experts"] + group_sizes: [M, M, ..., M] (length B) + """ + if tokamax_api is None: + raise ImportError("tokamax_api not found. 
Cannot run gemm_ragged_dot.") + + # Convert string dtypes to jnp dtypes + in_dtype = str_to_dtype(in_dtype_str) + out_dtype = str_to_dtype(out_dtype_str) + + def f(lhs, rhs, group_sizes): + with jax.named_scope(MARKER): + # Call the ragged_dot kernel + output_stacked = tokamax_api.ragged_dot( + lhs=lhs, + rhs=rhs, + group_sizes=group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=jnp.float32, # Use FP32 accumulation + implementation="mosaic" # Use the high-performance TPU kernel + ) + + # Reshape the output from (B*M, N) back to (B, M, N) + output_batched = output_stacked.reshape((b, m, n)) + return output_batched.astype(out_dtype) + + mesh = create_mesh() + + # --- Sharding logic that respects SHARDING_STRATEGY --- + # We define sharding specs based on the *intent* of the strategy + + match SHARDING_STRATEGY: + case ShardingStrategy.NO_SHARDING: + # Replicate everything + lhs_spec = P(None, None) # (B*M, K) + rhs_spec = P(None, None, None) # (B, K, N) + group_sizes_spec = P(None,) # (B,) + out_spec = P(None, None, None) # (B, M, N) + + case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M | ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M: + # This is Token/Expert Parallelism. We shard along the 'device' axis. + # 'device' axis is defined in create_mesh() + lhs_spec = P("device", None) # Shard (B*M) tokens + rhs_spec = P("device", None, None) # Shard (B) experts + group_sizes_spec = P("device",) # Shard (B) group_sizes + out_spec = P("device", None, None) # Output is sharded by B + + case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N | ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N: + # This is Tensor Parallelism. We shard along the 'N' dimension. + lhs_spec = P(None, None) # Replicate (B*M) tokens + rhs_spec = P(None, None, "device") # Shard (N) dimension + group_sizes_spec = P(None,) # Replicate group_sizes + out_spec = P(None, None, "device") # Output is sharded by N + + case _: + raise ValueError(f"Unsupported SHARDING_STRATEGY for ragged_dot: {SHARDING_STRATEGY}") + + # Create the full NamedSharding objects + lhs_sharding = NamedSharding(mesh, lhs_spec) + rhs_sharding = NamedSharding(mesh, rhs_spec) + group_sizes_sharding = NamedSharding(mesh, group_sizes_spec) + # --- End of sharding logic --- + + jit_sharded_f = jax.jit( + shard_map( + f, + mesh, + in_specs=(lhs_sharding.spec, rhs_sharding.spec, group_sizes_sharding.spec), + out_specs=out_spec, + check_rep=False, + ) + ) + + # Note the 2D and 3D shapes required for the inputs + lhs_shape = (b * m, k) # Stacked tokens + rhs_shape = (b, k, n) # Stacked expert weights + group_sizes_shape = (b,) + + key = jax.random.key(SEED) + + def data_generator(): + """Creates new random data on host and puts it on device.""" + nonlocal key + key, key_lhs, key_rhs = jax.random.split(key, 3) + + # Create random data on host + lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(in_dtype) + rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(in_dtype) + + # Create the group_sizes array: [m, m, m, ...] 
+ group_sizes_host = jnp.full(group_sizes_shape, m, dtype=jnp.int32) + + # Put on device (HBM) with the new sharding + lhs_device = jax.device_put(lhs_host, lhs_sharding) + rhs_device = jax.device_put(rhs_host, rhs_sharding) + group_sizes_device = jax.device_put(group_sizes_host, group_sizes_sharding) + + return (lhs_device, rhs_device, group_sizes_device) + + # Run the benchmark + time_ms_list = iteration_timeit( + jit_sharded_f, + data_generator, + matrix_dim=f"{b}x{m}x{n}x{k}", + tries=num_runs, + task=f"gemm_ragged_dot_{in_dtype_str}_{out_dtype_str}", + trace_dir=trace_dir, + ) + return {"time_ms_list": time_ms_list} + +def gemm_grouped_ragged_dot_calculate_metrics( + b: int, m: int, k: int, n: int, + in_dtype_str: str, out_dtype_str: str, + time_ms_list: list[float] +) -> Dict[str, Any]: + + # The total FLOPS are identical to the batched simple GEMM + # (2*k-1) FLOPS per element, times M*N elements, times B batches + total_flops_base = b * ( (2 * k - 1) * m * n ) + + # Get per-device and all-device FLOPS based on sharding strategy + # Your `handle_based_on_sharding` function should work perfectly here! + total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(total_flops_base) + + # Get the multiplier by calling the utility function + peak_flops_multiplier = get_peak_flops_multiplier(in_dtype_str) + + metadata, metrics = unified_flops_metrics( + m, n, k, time_ms_list, + total_flops_per_device, + total_flops_all_devices, + PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier + ) + + # Manually add the new parameters to the metadata for logging + metadata["b"] = b + metadata["in_dtype"] = in_dtype_str + metadata["out_dtype"] = out_dtype_str + metadata["kernel"] = "ragged_dot" # Add kernel name + + return metadata, metrics + def gemm( m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None ) -> Dict[str, Any]: diff --git a/src/run_benchmark.py b/src/run_benchmark.py index 0c792f9..6a078ef 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -63,6 +63,7 @@ GEMM_BENCHMARK_MAP = { "gemm_simple":"benchmark_gemm.gemm_simple", "gemm_batched_simple":"benchmark_gemm.gemm_batched_simple", + "gemm_grouped_ragged_dot":"benchmark_gemm.gemm_grouped_ragged_dot", "gemm": "benchmark_gemm.gemm", "gemm_batched":"benchmark_gemm.gemm_batched", "gemm_accum": "benchmark_gemm.gemm_accum", From 5db4ba05acee97fb12de5c33d2e7352d752b2dd9 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 5 Nov 2025 11:49:46 -0800 Subject: [PATCH 10/14] Replacing tokamax.ragged_dot() with jax.lax.ragged_dot(). --- src/benchmark_gemm.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 3f55124..c1d5813 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -512,14 +512,6 @@ def gemm_batched_simple_calculate_metrics( return metadata, metrics -# --- Add this import at the top of your file --- -try: - import tokamax_api -except ImportError: - print("Warning: tokamax_api not found. gemm_ragged_dot will not be available.") - tokamax_api = None -# --- - def gemm_grouped_ragged_dot( b: int, m: int, k: int, n: int, in_dtype_str: str, out_dtype_str: str, @@ -532,8 +524,6 @@ def gemm_grouped_ragged_dot( RHS: (B, K, N) [B is the number of "experts"] group_sizes: [M, M, ..., M] (length B) """ - if tokamax_api is None: - raise ImportError("tokamax_api not found. 
Cannot run gemm_ragged_dot.") # Convert string dtypes to jnp dtypes in_dtype = str_to_dtype(in_dtype_str) @@ -542,13 +532,12 @@ def gemm_grouped_ragged_dot( def f(lhs, rhs, group_sizes): with jax.named_scope(MARKER): # Call the ragged_dot kernel - output_stacked = tokamax_api.ragged_dot( + output_stacked = jax.lax.ragged_dot( lhs=lhs, rhs=rhs, group_sizes=group_sizes, precision=jax.lax.Precision.DEFAULT, - preferred_element_type=jnp.float32, # Use FP32 accumulation - implementation="mosaic" # Use the high-performance TPU kernel + preferred_element_type=jnp.float32 # Use FP32 accumulation ) # Reshape the output from (B*M, N) back to (B, M, N) From 18cef9b8a241fc27756314aa7649a76230294546 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 9 Nov 2025 22:00:54 -0800 Subject: [PATCH 11/14] Adding tiling scope for ragged_dot API usage. Fixed the iteration metrics to get correct ragged_dot wall_time. Added a try catch to continue the sweap even if some configs failed due to lack of resources. --- src/benchmark_gemm.py | 30 ++++++++++----- src/benchmark_utils.py | 3 +- src/run_benchmark.py | 86 ++++++++++++++++++++++-------------------- 3 files changed, 67 insertions(+), 52 deletions(-) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index c1d5813..2f242d2 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -515,9 +515,10 @@ def gemm_batched_simple_calculate_metrics( def gemm_grouped_ragged_dot( b: int, m: int, k: int, n: int, in_dtype_str: str, out_dtype_str: str, - num_runs: int = 1, trace_dir: str = None + num_runs: int = 1, trace_dir: str = None, + ragged_dot_tiling: str = None, ) -> Dict[str, Any]: - """Benchmarks a Batched GEMM using tokamax_api.ragged_dot. + """Benchmarks a Batched GEMM using jax.lax.ragged_dot. This implements (B, M, K) @ (B, K, N) by treating it as a Grouped GEMM: LHS: (B*M, K) @@ -531,18 +532,27 @@ def gemm_grouped_ragged_dot( def f(lhs, rhs, group_sizes): with jax.named_scope(MARKER): - # Call the ragged_dot kernel - output_stacked = jax.lax.ragged_dot( - lhs=lhs, - rhs=rhs, - group_sizes=group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=jnp.float32 # Use FP32 accumulation + # Use the tiling context if provided, otherwise do nothing. + tiling_ctx = ( + set_xla_metadata(ragged_dot_tiling=ragged_dot_tiling) + if ragged_dot_tiling + else contextlib.nullcontext() ) + with tiling_ctx: + # Call the ragged_dot kernel. 
Accumulation is fp32 by default + output_stacked = jax.lax.ragged_dot( + lhs=lhs, + rhs=rhs, + group_sizes=group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=out_dtype + ) + # Reshape the output from (B*M, N) back to (B, M, N) output_batched = output_stacked.reshape((b, m, n)) - return output_batched.astype(out_dtype) + + return output_batched mesh = create_mesh() diff --git a/src/benchmark_utils.py b/src/benchmark_utils.py index fec0f2d..ebbe87d 100644 --- a/src/benchmark_utils.py +++ b/src/benchmark_utils.py @@ -68,10 +68,11 @@ def iteration_timeit_from_trace( def iteration_get_metrics_from_trace(trace: dict[str, Any]) -> list[float]: marker_done_events = [] + events_lookup =[MARKER, "ragged-dot"] for event in trace["traceEvents"]: args = event.get("args", {}) tf_op = args.get("tf_op", "") - if MARKER in tf_op: + if any(s in tf_op for s in events_lookup): marker_done_events.append(event) # print(marker_done_events) diff --git a/src/run_benchmark.py b/src/run_benchmark.py index 6a078ef..4197e17 100644 --- a/src/run_benchmark.py +++ b/src/run_benchmark.py @@ -328,47 +328,51 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str): test_start_time = ( datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" ) # "Z" indicates UTC - benchmark_results = benchmark_func(**benchmark_param) - test_end_time = ( - datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" - ) - - # Filter benchmark_results to include only keys present in - # calculate_metrics_func - calculate_metrics_params = inspect.signature(calculate_metrics_func).parameters - filtered_benchmark_results = { - key: value - for key, value in benchmark_results.items() - if key in calculate_metrics_params - } - # Filter out certain parameters from benchmark_param, eg. "num_runs". - benchmark_params_to_filter = ["num_runs", "trace_dir"] - filtered_benchmark_param = { - key: value - for key, value in benchmark_param.items() - if key not in benchmark_params_to_filter - } - metadata, metrics = calculate_metrics_func( - **filtered_benchmark_param, **filtered_benchmark_results - ) - calculate_metrics_results.append({"metadata": metadata, "metrics": metrics}) - if xlml_metrics_dir: - maybe_write_metrics_file( - xlml_metrics_dir, - metrics, - metadata, - benchmark_name, - test_start_time, - test_end_time, - ) - # Post process the xla dump - if xla_dump_dir: - rename_xla_dump( - tmp_xla_dump_dir=TMP_XLA_DUMP_DIR, - dest_xla_dump_dir=xla_dump_dir, - benchmark_name=benchmark_name, - benchmark_param=original_benchmark_param, - ) + try: + benchmark_results = benchmark_func(**benchmark_param) + test_end_time = ( + datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z" + ) + + # Filter benchmark_results to include only keys present in + # calculate_metrics_func + calculate_metrics_params = inspect.signature(calculate_metrics_func).parameters + filtered_benchmark_results = { + key: value + for key, value in benchmark_results.items() + if key in calculate_metrics_params + } + # Filter out certain parameters from benchmark_param, eg. "num_runs". 
+ benchmark_params_to_filter = ["num_runs", "trace_dir"] + filtered_benchmark_param = { + key: value + for key, value in benchmark_param.items() + if key not in benchmark_params_to_filter + } + metadata, metrics = calculate_metrics_func( + **filtered_benchmark_param, **filtered_benchmark_results + ) + calculate_metrics_results.append({"metadata": metadata, "metrics": metrics}) + if xlml_metrics_dir: + maybe_write_metrics_file( + xlml_metrics_dir, + metrics, + metadata, + benchmark_name, + test_start_time, + test_end_time, + ) + # Post process the xla dump + if xla_dump_dir: + rename_xla_dump( + tmp_xla_dump_dir=TMP_XLA_DUMP_DIR, + dest_xla_dump_dir=xla_dump_dir, + benchmark_name=benchmark_name, + benchmark_param=original_benchmark_param, + ) + except Exception as e: + print(f"-------------- failed with error {e} -----") + continue # Dump metrics to file. if csv_path: From 91ee081f6216f0b49624136d0697438f1cfad0d1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 9 Nov 2025 22:58:07 -0800 Subject: [PATCH 12/14] Fixing a mising import. --- src/benchmark_gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 2f242d2..6fecb2c 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -11,6 +11,7 @@ from benchmark_utils import simple_timeit, MetricsStatistics, iteration_timeit import jax from jax.experimental.shard_map import shard_map +from jax.experimental.xla_metadata import set_xla_metadata import jax.numpy as jnp from jax.sharding import Mesh from jax.sharding import NamedSharding From b1f761406b9ee0b12a490a22714cfb36bfe5f96c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 9 Nov 2025 23:05:14 -0800 Subject: [PATCH 13/14] Adding the tiling parameter into the calculation metric function though it won't be used in the calculations. --- src/benchmark_gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 6fecb2c..062e751 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -642,6 +642,7 @@ def data_generator(): def gemm_grouped_ragged_dot_calculate_metrics( b: int, m: int, k: int, n: int, in_dtype_str: str, out_dtype_str: str, + ragged_dot_tiling: str, time_ms_list: list[float] ) -> Dict[str, Any]: From 80007f5e87d2545725093c8aa0a5fe07b32568ff Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 10 Nov 2025 10:51:57 -0800 Subject: [PATCH 14/14] Adding ragged_dot_tiling to the metric metadata. --- src/benchmark_gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/benchmark_gemm.py b/src/benchmark_gemm.py index 062e751..354429f 100644 --- a/src/benchmark_gemm.py +++ b/src/benchmark_gemm.py @@ -668,6 +668,7 @@ def gemm_grouped_ragged_dot_calculate_metrics( metadata["b"] = b metadata["in_dtype"] = in_dtype_str metadata["out_dtype"] = out_dtype_str + metadata["ragged_dot_tiling"] = ragged_dot_tiling metadata["kernel"] = "ragged_dot" # Add kernel name return metadata, metrics
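[Editor's note] Patches 09-14 move the grouped GEMM onto jax.lax.ragged_dot, treating the (B, M, K) x (B, K, N) batched matmul as a grouped GEMM with B equal-sized groups: the LHS rows are stacked into (B*M, K), the RHS holds one (K, N) weight matrix per group, and group_sizes is [M] * B. A minimal sketch (shapes illustrative, not part of the patches) that checks this equivalence against a plain batched einsum:

import jax
import jax.numpy as jnp

b, m, k, n = 4, 8, 16, 32
key_lhs, key_rhs = jax.random.split(jax.random.key(0))
lhs = jax.random.normal(key_lhs, (b * m, k)).astype(jnp.bfloat16)   # stacked "tokens"
rhs = jax.random.normal(key_rhs, (b, k, n)).astype(jnp.bfloat16)    # one weight matrix per group
group_sizes = jnp.full((b,), m, dtype=jnp.int32)                    # every group owns m rows of lhs

grouped = jax.lax.ragged_dot(
    lhs, rhs, group_sizes, preferred_element_type=jnp.float32
).reshape(b, m, n)
batched = jnp.einsum("bij,bjk->bik", lhs.reshape(b, m, k), rhs,
                     preferred_element_type=jnp.float32)
print(jnp.max(jnp.abs(grouped - batched)))  # expected ~0

Patch 11 additionally wraps the ragged_dot call in set_xla_metadata(ragged_dot_tiling=...) when a tiling string is supplied, widens the trace-event lookup so the ragged-dot op's wall time is captured, and makes the sweep continue past configurations that fail.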