diff --git a/src/benchmark_collectives.py b/src/benchmark_collectives.py index 6eb0b37..c31c6da 100644 --- a/src/benchmark_collectives.py +++ b/src/benchmark_collectives.py @@ -78,7 +78,7 @@ def psum_benchmark( # DCN benchmark if dcn_size > 1: - @partial(shard_map, mesh=mesh, in_specs=P("dcn", None), out_specs=P(None)) + @partial(shard_map, mesh=mesh, in_specs=P("dcn", None), out_specs=P(None, None)) def f(x): return jax.lax.psum(x, "dcn") @@ -99,12 +99,12 @@ def f(x): # ICI benchmark if ici_size > 1: - @partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, None)) + @partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, None)) def f(x): return jax.lax.psum(x, "ici") sharded_matrix = jax.device_put( - matrix, jax.sharding.NamedSharding(mesh, P(None, None)) + matrix, jax.sharding.NamedSharding(mesh, P(None, "ici")) ) jitted_op = jax.jit(f) ici_average_time_ms_list = simple_timeit( @@ -167,6 +167,7 @@ def psum_benchmark_calculate_metrics( * (ici_size - 1) * 2 / ici_size + / ici_size / (ici_average_time_ms / 1e3) for ici_average_time_ms in ici_average_time_ms_list ] @@ -235,12 +236,12 @@ def f(x): # ICI benchmark if ici_size > 1: - @partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, "ici")) + @partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, "ici")) def f(x): return jax.lax.psum_scatter(x, "ici", tiled=True) sharded_matrix = jax.device_put( - matrix, jax.sharding.NamedSharding(mesh, P(None, None)) + matrix, jax.sharding.NamedSharding(mesh, P(None, "ici")) ) jitted_op = jax.jit(f) ici_average_time_ms_list = simple_timeit( @@ -303,6 +304,7 @@ def psum_scatter_benchmark_calculate_metrics( matrix_size_gbyte * (ici_size - 1) / ici_size + / ici_size / (ici_average_time_ms / 1e3) for ici_average_time_ms in ici_average_time_ms_list ] @@ -443,9 +445,10 @@ def all_gather_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs # to use (ici_size - 1) steps in a ring algorithm ici_bandwidth_gbyte_s_list = [ - matrix_size_gbyte - * (ici_size - 1) - / ici_size + matrix_size_gbyte + * (ici_size - 1) + / ici_size + / ici_size / (ici_average_time_ms / 1e3) for ici_average_time_ms in ici_average_time_ms_list ] @@ -517,13 +520,13 @@ def f(x): # ICI benchmark if ici_size > 1: - @partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, "ici")) + @partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, "ici")) def f(x): perm = [(i, (i + 1) % ici_size) for i in range(ici_size)] return jax.lax.ppermute(x, "ici", perm) sharded_matrix = jax.device_put( - matrix, jax.sharding.NamedSharding(mesh, P(None, None)) + matrix, jax.sharding.NamedSharding(mesh, P(None, "ici")) ) jitted_op = jax.jit(f) ici_average_time_ms_list = simple_timeit( @@ -579,7 +582,9 @@ def ppermute_benchmark_calculate_metrics( # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs # to use 1 step ici_bandwidth_gbyte_s_list = [ - matrix_size_gbyte / (ici_average_time_ms / 1e3) + matrix_size_gbyte + / ici_size + / (ici_average_time_ms / 1e3) for ici_average_time_ms in ici_average_time_ms_list ] ici_bandwidth_gbyte_s_statistics = MetricsStatistics( @@ -648,15 +653,15 @@ def f(x): @partial( shard_map, mesh=mesh, - in_specs=P(None, None), - out_specs=P(None, None), + in_specs=P(None, "ici"), + out_specs=P(None, "ici"), check_rep=False, ) def f(x): return jax.lax.all_to_all(x, "ici", split_axis=0, concat_axis=0, tiled=True) sharded_matrix = jax.device_put( - matrix, jax.sharding.NamedSharding(mesh, P(None, None)) + matrix, jax.sharding.NamedSharding(mesh, P(None, "ici")) ) jitted_op = jax.jit(f) ici_average_time_ms_list = simple_timeit( @@ -716,6 +721,7 @@ def all_to_all_benchmark_calculate_metrics( matrix_size_gbyte * (ici_size - 1) / ici_size + / ici_size / (ici_average_time_ms / 1e3) for ici_average_time_ms in ici_average_time_ms_list ]