Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ testpaths =
python_files = *_test.py *_tests.py
addopts =
-rf --import-mode=importlib --strict-markers
--ignore=tests/dcn_bandwidth_test.py
--ignore=tests/post_training/integration/grpo_trainer_correctness_test.py
--ignore=tests/integration/smoke/train_gpu_smoke_test.py
--ignore=tests/integration/smoke/train_int8_smoke_test.py
Expand Down
9 changes: 9 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,15 @@ enable_diloco: false
diloco_sync_period: 36
diloco_outer_lr: 0.3
diloco_outer_momentum: 0.9
# DCN bandwidth throttling parameters (used for simulating slow networks).
# If dcn_bandwidth_limit is empty, no throttling is applied.
dcn_bandwidth_limit: ""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add comments

# The burst size parameter for the traffic control (tc) token bucket filter.
dcn_bandwidth_burst: "10mb"
# The latency parameter for the traffic control (tc) token bucket filter.
dcn_bandwidth_latency: "50ms"
# The network interface to apply throttling rules to.
dcn_bandwidth_interface: "eth0"

# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,14 @@ class DilocoParams(BaseModel):
diloco_sync_period: int = Field(36, description="Diloco sync period.")
diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.")
diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.")
dcn_bandwidth_limit: str = Field(
"", description="Programmatic DCN egress bandwidth limit (e.g., '28gbit'). Empty means no limit."
)
dcn_bandwidth_burst: str = Field("10mb", description="Burst size for Token Bucket Filter (TBF) traffic shaping.")
dcn_bandwidth_latency: str = Field(
"50ms", description="Latency threshold for Token Bucket Filter (TBF) traffic shaping."
)
dcn_bandwidth_interface: str = Field("eth0", description="Network interface to apply bandwidth limits on.")


class Optimizer(BaseModel):
Expand Down
5 changes: 5 additions & 0 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,10 @@ def train_loop(config, recorder, state=None):
state,
) = train_utils.setup_train_loop(config, recorder)

# Throttling is applied only if configured (dcn_bandwidth_limit is set).
# The default flag value is empty, meaning no throttling is applied by default.
train_utils.maybe_apply_dcn_throttling(config)

start_step = get_first_step(model, state) # this is the start_step for training
train_utils.validate_completed_steps(start_step, config.steps)

Expand Down Expand Up @@ -739,6 +743,7 @@ def train_loop(config, recorder, state=None):
if _job_completed_gracefully:
record_goodput(recorder, RECORD_JOB_END_TIME)
metric_logger_instance.flush_metrics_and_cleanup()
train_utils.cleanup_dcn_throttling(config)

return state

Expand Down
51 changes: 50 additions & 1 deletion src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
# pylint: disable=bare-except, consider-using-generator
"""Utils that are only interesting for training in MaxText."""

import subprocess
import jax
import functools
from functools import partial

from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax

from maxtext.common import checkpointing
from maxtext.common import train_state_nnx
from maxtext.common.common_types import ReorderStrategy
Expand Down Expand Up @@ -378,3 +380,50 @@ def validate_completed_steps(completed_steps: int, config_steps: int):
f"Did you mean to continue training past step {completed_steps} (you should set steps > {completed_steps}) "
f"or to not load the checkpoint (use enable_checkpointing=False?)"
)


def maybe_apply_dcn_throttling(config):
  """Installs a tc token bucket filter (tbf) egress limit when one is configured.

  Any root qdisc already present on `config.dcn_bandwidth_interface` is deleted
  first, whether or not a limit is configured. If `config.dcn_bandwidth_limit`
  is empty nothing further happens. Every failure is logged and swallowed, so
  this function never raises because of `tc`.

  Args:
    config: Object exposing `dcn_bandwidth_interface`, `dcn_bandwidth_limit`,
      `dcn_bandwidth_burst` and `dcn_bandwidth_latency`.
  """
  interface = config.dcn_bandwidth_interface

  # Start from a clean slate: drop whatever root qdisc is currently installed.
  # The exit status is deliberately ignored (check=False) and output is discarded.
  delete_root_qdisc = ["tc", "qdisc", "del", "dev", interface, "root"]
  try:
    subprocess.run(
        delete_root_qdisc,
        check=False,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
  except Exception as err:  # pylint: disable=broad-exception-caught
    max_logging.error(f"Failed to clean up existing traffic control on {interface}: {err}")

  rate = config.dcn_bandwidth_limit
  if not rate:
    # An empty limit means throttling is disabled.
    return

  burst = config.dcn_bandwidth_burst
  latency = config.dcn_bandwidth_latency
  max_logging.log(f"Applying tc egress limit of {rate} (burst: {burst}, latency: {latency}) on {interface}...")

  add_tbf_qdisc = [
      "tc", "qdisc", "add", "dev", interface, "root",
      "tbf", "rate", rate, "burst", burst, "latency", latency,
  ]
  try:
    # check=True turns a non-zero tc exit status into an exception handled below.
    subprocess.run(add_tbf_qdisc, check=True)
    max_logging.log("DCN Bandwidth throttling applied successfully.")
  except Exception as err:  # pylint: disable=broad-exception-caught
    max_logging.error(f"Failed to apply DCN bandwidth throttling: {err}")


def cleanup_dcn_throttling(config):
  """Removes the root traffic control (tc) qdisc from the configured interface.

  This is best-effort: nothing is raised to the caller. Success is only logged
  when `tc` actually exits with status 0; previously the success message was
  logged regardless of the exit status.

  Args:
    config: Object exposing `dcn_bandwidth_interface`.
  """
  interface = config.dcn_bandwidth_interface
  max_logging.log(f"Cleaning up tc egress limit on {interface}...")
  try:
    result = subprocess.run(
        ["tc", "qdisc", "del", "dev", interface, "root"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=False,
    )
  except Exception as e:  # pylint: disable=broad-exception-caught
    # Typically `tc` is missing or cannot be executed.
    max_logging.error(f"Failed to clean up DCN bandwidth throttling: {e}")
    return

  if result.returncode == 0:
    max_logging.log("DCN Bandwidth throttling cleaned up successfully.")
  else:
    # A non-zero exit is expected when there was no rule to delete (for example
    # throttling was never applied), so report it at info level, not as an error.
    max_logging.log(f"No tc egress limit was removed from {interface} (tc exited with status {result.returncode}).")
200 changes: 200 additions & 0 deletions tests/dcn_bandwidth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This test/benchmark requires a multislice TPU setup to run correctly.
# It is excluded from standard pytest discovery.

"""A microbenchmark to test DCN network bandwidth using shard map.

This script should be run on a multi-slice TPU cluster (specifically across 2 slices,
with any ICI/slice dimensions).
"""

import datetime
import functools
import subprocess
from types import SimpleNamespace

import jax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from maxtext.utils import train_utils


def get_default_interface():
  """Returns the interface named on the default route, falling back to "eth0".

  The output of `ip route show` is scanned for a line whose first field is
  `default`; the field following `dev` on that line is returned. "eth0" is
  returned when `ip` cannot be run, exits with an error, or no default route
  names a device.
  """
  fallback = "eth0"
  try:
    # An argv list avoids spawning a shell for a fixed command.
    route_output = subprocess.check_output(["ip", "route", "show"], text=True)
  except (subprocess.SubprocessError, OSError):
    # OSError covers a missing or non-executable `ip` binary.
    return fallback

  for line in route_output.splitlines():
    fields = line.split()
    # Match whole fields so that "default"/"dev" embedded in other words are ignored.
    if not fields or fields[0] != "default" or "dev" not in fields:
      continue
    name_pos = fields.index("dev") + 1
    if name_pos < len(fields):
      return fields[name_pos]
  return fallback


def simple_timeit(f, *args, tries=10, task=None):
  """Returns the average time, in milliseconds, of `tries` calls to `f(*args)`.

  One untimed warm-up call is made first. Each timed call is wrapped in
  `jax.block_until_ready` so that the measurement covers completion of the
  result rather than just dispatch.

  Args:
    f: Callable to benchmark.
    *args: Positional arguments forwarded to `f`.
    tries: Number of timed calls; must be at least 1.
    task: Label for the benchmark. It must be provided but is otherwise unused
      here.

  Returns:
    The mean duration of the timed calls in milliseconds.

  Raises:
    ValueError: If `tries` is less than 1.
    AssertionError: If `task` is None.
  """
  import time  # pylint: disable=import-outside-toplevel

  assert task is not None
  if tries < 1:
    raise ValueError(f"tries must be at least 1, got {tries}")

  # Warm up
  jax.block_until_ready(f(*args))

  outcomes_ms = []
  for _ in range(tries):
    # NOTE(review): kept from the original, which labelled this call "Force
    # synchronization"; it only enumerates devices here - confirm it is needed.
    jax.devices()
    # perf_counter is monotonic and high resolution, unlike the wall clock,
    # so clock adjustments during a run cannot distort (or negate) a sample.
    start = time.perf_counter()
    jax.block_until_ready(f(*args))
    outcomes_ms.append(1000 * (time.perf_counter() - start))

  return sum(outcomes_ms) / len(outcomes_ms)


def create_mesh(dcn_size: int, ici_size: int):
  """Builds a 2-D device mesh with axes named ("dcn", "ici").

  Args:
    dcn_size: Extent of the "dcn" axis.
    ici_size: Extent of the "ici" axis.

  Returns:
    A `Mesh` over all of `jax.devices()`.

  Raises:
    ValueError: If `jax.device_count()` differs from `dcn_size * ici_size`.
  """
  required = dcn_size * ici_size
  available = jax.device_count()
  if available != required:
    raise ValueError(f"Need {required} devices, but found {available}")

  device_grid = mesh_utils.create_hybrid_device_mesh(
      [1, ici_size],  # ICI parallelism
      [dcn_size, 1],  # DCN parallelism
      devices=jax.devices(),
  )
  return Mesh(device_grid, ("dcn", "ici"))


def run_dcn_benchmark(case=None, limit=None, burst=None, latency="50ms"):
  """Runs the DCN bandwidth benchmark for specified throttling cases.

  A fixed-shape bfloat16 array, sharded over the "dcn" mesh axis, is summed with
  `jax.lax.psum` along that axis. For every selected case the averaged run time
  is converted into a bandwidth figure, which process 0 prints. Nothing is
  returned.

  Args:
    case: Key of the case to run ("none", "100g" or "50g"). When falsy, all
      predefined cases run. A key that matches no predefined case is run as a
      custom throttled case, but only when both `limit` and `burst` are given.
    limit: Bandwidth limit of a custom case. Only consulted when `case` matches
      no predefined case.
    burst: Burst size of a custom case. Only consulted when `case` matches no
      predefined case.
    latency: Latency value placed in the throttling config of every throttled
      case.

  Raises:
    ValueError: If the device count is not divisible by 2, or if `case` matches
      no predefined case and `limit`/`burst` are not both provided.
  """
  print(f"JAX process index: {jax.process_index()} / {jax.process_count()}")
  print(f"Total devices: {jax.device_count()}, local devices: {jax.local_device_count()}")

  # The benchmark is hard-wired to two groups along the "dcn" axis; the rest of
  # the devices form the "ici" axis.
  dcn_size = 2
  total_devices = jax.device_count()
  if total_devices % dcn_size != 0:
    raise ValueError(f"Total devices ({total_devices}) must be divisible by dcn_size ({dcn_size}) for 2-slice setup.")
  ici_size = total_devices // dcn_size
  mesh = create_mesh(dcn_size, ici_size)

  # Predefined cases
  # Each entry is (key, display name, apply throttling?, limit, burst).
  all_cases = [
      ("none", "NO THROTTLING (Baseline)", False, None, None),
      ("100g", "100G Throttling", True, "100gbit", "600mb"),
      ("50g", "50G Throttling", True, "50gbit", "300mb"),
  ]

  # Filter based on arguments if requested
  if case:
    cases_to_run = [c for c in all_cases if c[0] == case]
    if not cases_to_run:
      # If custom limit/burst are provided
      if limit and burst:
        cases_to_run = [(case, f"CUSTOM Throttling ({case})", True, limit, burst)]
      else:
        raise ValueError(f"Unknown case: {case}")
  else:
    # NOTE: `limit` and `burst` are not consulted on this path.
    cases_to_run = all_cases

  # Qwen3-30B MoE layer weight shape: (128, 2048, 768)
  shape = (128, 2048, 768)
  dtype = jnp.bfloat16

  # Calculate size
  num_elements = 1
  for d in shape:
    num_elements *= d
  # Bytes of the full array, expressed in units of 1e9 bytes.
  matrix_size_gbyte = num_elements * dtype.dtype.itemsize / 1e9

  # We define shard map collective psum along the DCN axis.
  # Input x is sharded across 'dcn' axis.
  @functools.partial(shard_map, mesh=mesh, in_specs=P("dcn", None, None), out_specs=P(None, None, None))
  def psum_dcn_op(x):
    return jax.lax.psum(x, "dcn")

  # Initialize matrix
  matrix = jnp.ones(shape, dtype=dtype)

  # Pre-distribute the matrix shard onto devices
  sharded_matrix = jax.device_put(matrix, jax.sharding.NamedSharding(mesh, P("dcn", None, None)))

  jitted_op = jax.jit(psum_dcn_op)

  # Throttling rules are applied to the interface reported by get_default_interface().
  interface = get_default_interface()

  for _, name, apply_throttling, dcn_limit, dcn_burst in cases_to_run:
    # Only process 0 prints, so the banner appears once.
    if jax.process_index() == 0:
      print("\n==================================================")
      print(f"Running Case: {name}")
      if apply_throttling:
        print(f" Throttling Config: limit={dcn_limit}, burst={dcn_burst}, latency={latency}")
      print("==================================================")

    if apply_throttling:
      # A minimal stand-in config carrying the four attributes passed to the
      # train_utils throttling helpers.
      config = SimpleNamespace(
          dcn_bandwidth_limit=dcn_limit,
          dcn_bandwidth_burst=dcn_burst,
          dcn_bandwidth_latency=latency,
          dcn_bandwidth_interface=interface,
      )
      train_utils.maybe_apply_dcn_throttling(config)
    else:
      config = None

    try:
      # Sync before starting benchmark
      # NOTE(review): this waits for a trivial device computation; whether it
      # acts as a cross-process barrier is not evident from this file - confirm.
      jax.block_until_ready(jax.device_put(0.0) + 1.0)

      if jax.process_index() == 0:
        print(f"Starting benchmark for shape: {shape} ({matrix_size_gbyte * 1000:.1f} MB)")

      # Run time test
      time_ms = simple_timeit(jitted_op, sharded_matrix, task=f"psum_dcn_{shape}")

      # Calculate Bandwidth
      # time_ms / 1e3 converts milliseconds to seconds.
      # NOTE(review): the factor looks like the 2 * (n - 1) / n all-reduce
      # traffic term applied to a 1 / n sized shard - confirm.
      achieved_bandwidth_gbyte_s = matrix_size_gbyte * (dcn_size - 1) * 2 / dcn_size / dcn_size / (time_ms / 1e3)
      # Bytes to bits.
      achieved_bandwidth_gbps = achieved_bandwidth_gbyte_s * 8.0

      if jax.process_index() == 0:
        print(f"Results for {name}:")
        print(f" Avg Latency: {time_ms:.2f} ms")
        print(
            f" Achieved DCN Bandwidth: {achieved_bandwidth_gbyte_s:.3f} GB/s ({achieved_bandwidth_gbps:.2f} Gbps) per slice"
        )
    finally:
      # Remove the throttling rule even if the benchmark raised, so it does not
      # leak into the next case. Every process cleans up; only process 0 prints.
      if apply_throttling and config:
        if jax.process_index() == 0:
          print(f"Cleaning up throttling for {name}...")
        train_utils.cleanup_dcn_throttling(config)

      # Sync after cleanup
      jax.block_until_ready(jax.device_put(0.0) + 1.0)


if __name__ == "__main__":
  import argparse

  # Command-line entry point: parse the throttling options and run the benchmark.
  parser = argparse.ArgumentParser(description="DCN Bandwidth Benchmark")
  # `--case` deliberately has no `choices=` restriction. run_dcn_benchmark only
  # builds a custom case when the name matches no predefined case, so limiting
  # the name to the predefined ones made --limit/--burst impossible to use.
  # Unknown names without both --limit and --burst are still rejected there.
  parser.add_argument(
      "--case",
      help=(
          "Throttling case to run: 'none', '100g', '50g', or a custom label used together with "
          "--limit and --burst (optional, runs all predefined cases if omitted)"
      ),
  )
  parser.add_argument("--limit", help="Custom DCN bandwidth limit (e.g. 100gbit); requires a custom --case label")
  parser.add_argument("--burst", help="Custom DCN bandwidth burst (e.g. 600mb); requires a custom --case label")
  parser.add_argument("--latency", default="50ms", help="DCN bandwidth latency (default: 50ms)")
  parsed_args = parser.parse_args()

  run_dcn_benchmark(case=parsed_args.case, limit=parsed_args.limit, burst=parsed_args.burst, latency=parsed_args.latency)
Loading