Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions src/benchmark_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def psum_benchmark(
# DCN benchmark
if dcn_size > 1:

@partial(shard_map, mesh=mesh, in_specs=P("dcn", None), out_specs=P(None))
@partial(shard_map, mesh=mesh, in_specs=P("dcn", None), out_specs=P(None, None))
def f(x):
return jax.lax.psum(x, "dcn")

Expand All @@ -99,12 +99,12 @@ def f(x):
# ICI benchmark
if ici_size > 1:

@partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, None))
@partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, None))
def f(x):
return jax.lax.psum(x, "ici")

sharded_matrix = jax.device_put(
matrix, jax.sharding.NamedSharding(mesh, P(None, None))
matrix, jax.sharding.NamedSharding(mesh, P(None, "ici"))
)
jitted_op = jax.jit(f)
ici_average_time_ms_list = simple_timeit(
Expand Down Expand Up @@ -167,6 +167,7 @@ def psum_benchmark_calculate_metrics(
* (ici_size - 1)
* 2
/ ici_size
/ ici_size
/ (ici_average_time_ms / 1e3)
for ici_average_time_ms in ici_average_time_ms_list
]
Expand Down Expand Up @@ -235,12 +236,12 @@ def f(x):
# ICI benchmark
if ici_size > 1:

@partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, "ici"))
@partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, "ici"))
def f(x):
return jax.lax.psum_scatter(x, "ici", tiled=True)

sharded_matrix = jax.device_put(
matrix, jax.sharding.NamedSharding(mesh, P(None, None))
matrix, jax.sharding.NamedSharding(mesh, P(None, "ici"))
)
jitted_op = jax.jit(f)
ici_average_time_ms_list = simple_timeit(
Expand Down Expand Up @@ -303,6 +304,7 @@ def psum_scatter_benchmark_calculate_metrics(
matrix_size_gbyte
* (ici_size - 1)
/ ici_size
/ ici_size
/ (ici_average_time_ms / 1e3)
for ici_average_time_ms in ici_average_time_ms_list
]
Expand Down Expand Up @@ -443,9 +445,10 @@ def all_gather_benchmark_calculate_metrics(
# Each shard's size is matrix_size_gbyte / ici_size, and a ring algorithm
# then needs (ici_size - 1) steps.
ici_bandwidth_gbyte_s_list = [
matrix_size_gbyte
* (ici_size - 1)
/ ici_size
matrix_size_gbyte
* (ici_size - 1)
/ ici_size
/ ici_size
/ (ici_average_time_ms / 1e3)
for ici_average_time_ms in ici_average_time_ms_list
]
Expand Down Expand Up @@ -517,13 +520,13 @@ def f(x):
# ICI benchmark
if ici_size > 1:

@partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, "ici"))
@partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, "ici"))
def f(x):
perm = [(i, (i + 1) % ici_size) for i in range(ici_size)]
return jax.lax.ppermute(x, "ici", perm)

sharded_matrix = jax.device_put(
matrix, jax.sharding.NamedSharding(mesh, P(None, None))
matrix, jax.sharding.NamedSharding(mesh, P(None, "ici"))
)
jitted_op = jax.jit(f)
ici_average_time_ms_list = simple_timeit(
Expand Down Expand Up @@ -579,7 +582,9 @@ def ppermute_benchmark_calculate_metrics(
# Each shard's size is matrix_size_gbyte / ici_size, and it then needs
# only 1 step.
ici_bandwidth_gbyte_s_list = [
matrix_size_gbyte / (ici_average_time_ms / 1e3)
matrix_size_gbyte
/ ici_size
/ (ici_average_time_ms / 1e3)
for ici_average_time_ms in ici_average_time_ms_list
]
ici_bandwidth_gbyte_s_statistics = MetricsStatistics(
Expand Down Expand Up @@ -648,15 +653,15 @@ def f(x):
@partial(
shard_map,
mesh=mesh,
in_specs=P(None, None),
out_specs=P(None, None),
in_specs=P(None, "ici"),
out_specs=P(None, "ici"),
check_rep=False,
)
def f(x):
return jax.lax.all_to_all(x, "ici", split_axis=0, concat_axis=0, tiled=True)

sharded_matrix = jax.device_put(
matrix, jax.sharding.NamedSharding(mesh, P(None, None))
matrix, jax.sharding.NamedSharding(mesh, P(None, "ici"))
)
jitted_op = jax.jit(f)
ici_average_time_ms_list = simple_timeit(
Expand Down Expand Up @@ -716,6 +721,7 @@ def all_to_all_benchmark_calculate_metrics(
matrix_size_gbyte
* (ici_size - 1)
/ ici_size
/ ici_size
/ (ici_average_time_ms / 1e3)
for ici_average_time_ms in ici_average_time_ms_list
]
Expand Down