From 98f0c5db92616ea3c01ecea3ee1199c6c1518d24 Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Tue, 23 Jun 2026 02:57:40 +0000 Subject: [PATCH] dcn throttling changes --- pytest.ini | 1 + src/maxtext/configs/base.yml | 9 ++ src/maxtext/configs/types.py | 8 + src/maxtext/trainers/pre_train/train.py | 5 + src/maxtext/utils/train_utils.py | 51 +++++- tests/dcn_bandwidth_test.py | 200 ++++++++++++++++++++++++ 6 files changed, 273 insertions(+), 1 deletion(-) create mode 100644 tests/dcn_bandwidth_test.py diff --git a/pytest.ini b/pytest.ini index 8800326967..4976e05ebc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,6 +5,7 @@ testpaths = python_files = *_test.py *_tests.py addopts = -rf --import-mode=importlib --strict-markers + --ignore=tests/dcn_bandwidth_test.py --ignore=tests/post_training/integration/grpo_trainer_correctness_test.py --ignore=tests/integration/smoke/train_gpu_smoke_test.py --ignore=tests/integration/smoke/train_int8_smoke_test.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 1541170fee..63fc6ee0c9 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -872,6 +872,15 @@ enable_diloco: false diloco_sync_period: 36 diloco_outer_lr: 0.3 diloco_outer_momentum: 0.9 +# DCN bandwidth throttling parameters (used for simulating slow networks). +# If dcn_bandwidth_limit is empty, no throttling is applied. +dcn_bandwidth_limit: "" +# The burst size parameter for the traffic control (tc) token bucket filter. +dcn_bandwidth_burst: "10mb" +# The latency parameter for the traffic control (tc) token bucket filter. +dcn_bandwidth_latency: "50ms" +# The network interface to apply throttling rules to. +dcn_bandwidth_interface: "eth0" # You may disable clipping by setting gradient_clipping_threshold to zero. gradient_clipping_threshold: 1.0 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0d64347d60..f7700b5f78 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1445,6 +1445,14 @@ class DilocoParams(BaseModel): diloco_sync_period: int = Field(36, description="Diloco sync period.") diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.") diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.") + dcn_bandwidth_limit: str = Field( + "", description="Programmatic DCN egress bandwidth limit (e.g., '28gbit'). Empty means no limit." + ) + dcn_bandwidth_burst: str = Field("10mb", description="Burst size for Token Bucket Filter (TBF) traffic shaping.") + dcn_bandwidth_latency: str = Field( + "50ms", description="Latency threshold for Token Bucket Filter (TBF) traffic shaping." + ) + dcn_bandwidth_interface: str = Field("eth0", description="Network interface to apply bandwidth limits on.") class Optimizer(BaseModel): diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..674fa668e7 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -607,6 +607,10 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) + # Throttling is applied only if configured (dcn_bandwidth_limit is set). + # The default flag value is empty, meaning no throttling is applied by default. + train_utils.maybe_apply_dcn_throttling(config) + start_step = get_first_step(model, state) # this is the start_step for training train_utils.validate_completed_steps(start_step, config.steps) @@ -739,6 +743,7 @@ def train_loop(config, recorder, state=None): if _job_completed_gracefully: record_goodput(recorder, RECORD_JOB_END_TIME) metric_logger_instance.flush_metrics_and_cleanup() + train_utils.cleanup_dcn_throttling(config) return state diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index eb429f5446..233bba542e 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,12 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import subprocess +import jax import functools from functools import partial from flax import nnx from flax.linen import partitioning as nn_partitioning -import jax + from maxtext.common import checkpointing from maxtext.common import train_state_nnx from maxtext.common.common_types import ReorderStrategy @@ -378,3 +380,50 @@ def validate_completed_steps(completed_steps: int, config_steps: int): f"Did you mean to continue training past step {completed_steps} (you should set steps > {completed_steps}) " f"or to not load the checkpoint (use enable_checkpointing=False?)" ) + + +def maybe_apply_dcn_throttling(config): + """Applies programmatic traffic control (tc) bandwidth limit if configured.""" + interface = config.dcn_bandwidth_interface + + # Always clean up any existing traffic control rule on the interface first. + try: + subprocess.run( + ["tc", "qdisc", "del", "dev", interface, "root"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.error(f"Failed to clean up existing traffic control on {interface}: {e}") + + if not config.dcn_bandwidth_limit: + return + + rate = config.dcn_bandwidth_limit + burst = config.dcn_bandwidth_burst + latency = config.dcn_bandwidth_latency + + max_logging.log(f"Applying tc egress limit of {rate} (burst: {burst}, latency: {latency}) on {interface}...") + try: + cmd = ["tc", "qdisc", "add", "dev", interface, "root", "tbf", "rate", rate, "burst", burst, "latency", latency] + subprocess.run(cmd, check=True) + max_logging.log("DCN Bandwidth throttling applied successfully.") + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.error(f"Failed to apply DCN bandwidth throttling: {e}") + + +def cleanup_dcn_throttling(config): + """Cleans up traffic control (tc) rules.""" + interface = config.dcn_bandwidth_interface + max_logging.log(f"Cleaning up tc egress limit on {interface}...") + try: + subprocess.run( + ["tc", "qdisc", "del", "dev", interface, "root"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) + max_logging.log("DCN Bandwidth throttling cleaned up successfully.") + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.error(f"Failed to clean up DCN bandwidth throttling: {e}") diff --git a/tests/dcn_bandwidth_test.py b/tests/dcn_bandwidth_test.py new file mode 100644 index 0000000000..c6c98ba967 --- /dev/null +++ b/tests/dcn_bandwidth_test.py @@ -0,0 +1,200 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: This test/benchmark requires a multislice TPU setup to run correctly. +# It is excluded from standard pytest discovery. + +"""A microbenchmark to test DCN network bandwidth using shard map. + +This script should be run on a multi-slice TPU cluster (specifically across 2 slices +with any ICI/slice dimensions +""" + +import datetime +import functools +import subprocess +from types import SimpleNamespace + +import jax +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import PartitionSpec as P + +from maxtext.utils import train_utils + + +def get_default_interface(): + try: + route_output = subprocess.check_output("ip route show", shell=True, text=True) + for line in route_output.splitlines(): + if "default" in line: + return line.split("dev")[1].strip().split()[0] + except (subprocess.SubprocessError, IndexError): + pass + return "eth0" + + +def simple_timeit(f, *args, tries=10, task=None): + """Simple utility to time a function for multiple runs.""" + assert task is not None + outcomes_ms = [] + + # Warm up + jax.block_until_ready(f(*args)) + + for _ in range(tries): + jax.devices() # Force synchronization + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + return average_time_ms + + +def create_mesh(dcn_size: int, ici_size: int): + """Creates a hybrid mesh with DCN and ICI axes.""" + dcn_parallelism = [dcn_size, 1] + ici_parallelism = [1, ici_size] + + total_devices = jax.device_count() + if total_devices != (dcn_size * ici_size): + raise ValueError(f"Need {dcn_size * ici_size} devices, but found {total_devices}") + mesh_devices = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices=jax.devices()) + mesh = Mesh(mesh_devices, ("dcn", "ici")) + return mesh + + +def run_dcn_benchmark(case=None, limit=None, burst=None, latency="50ms"): + """Runs the DCN bandwidth benchmark for specified throttling cases.""" + print(f"JAX process index: {jax.process_index()} / {jax.process_count()}") + print(f"Total devices: {jax.device_count()}, local devices: {jax.local_device_count()}") + + dcn_size = 2 + total_devices = jax.device_count() + if total_devices % dcn_size != 0: + raise ValueError(f"Total devices ({total_devices}) must be divisible by dcn_size ({dcn_size}) for 2-slice setup.") + ici_size = total_devices // dcn_size + mesh = create_mesh(dcn_size, ici_size) + + # Predefined cases + all_cases = [ + ("none", "NO THROTTLING (Baseline)", False, None, None), + ("100g", "100G Throttling", True, "100gbit", "600mb"), + ("50g", "50G Throttling", True, "50gbit", "300mb"), + ] + + # Filter based on arguments if requested + if case: + cases_to_run = [c for c in all_cases if c[0] == case] + if not cases_to_run: + # If custom limit/burst are provided + if limit and burst: + cases_to_run = [(case, f"CUSTOM Throttling ({case})", True, limit, burst)] + else: + raise ValueError(f"Unknown case: {case}") + else: + cases_to_run = all_cases + + # Qwen3-30B MoE layer weight shape: (128, 2048, 768) + shape = (128, 2048, 768) + dtype = jnp.bfloat16 + + # Calculate size + num_elements = 1 + for d in shape: + num_elements *= d + matrix_size_gbyte = num_elements * dtype.dtype.itemsize / 1e9 + + # We define shard map collective psum along the DCN axis. + # Input x is sharded across 'dcn' axis. + @functools.partial(shard_map, mesh=mesh, in_specs=P("dcn", None, None), out_specs=P(None, None, None)) + def psum_dcn_op(x): + return jax.lax.psum(x, "dcn") + + # Initialize matrix + matrix = jnp.ones(shape, dtype=dtype) + + # Pre-distribute the matrix shard onto devices + sharded_matrix = jax.device_put(matrix, jax.sharding.NamedSharding(mesh, P("dcn", None, None))) + + jitted_op = jax.jit(psum_dcn_op) + + interface = get_default_interface() + + for _, name, apply_throttling, dcn_limit, dcn_burst in cases_to_run: + if jax.process_index() == 0: + print("\n==================================================") + print(f"Running Case: {name}") + if apply_throttling: + print(f" Throttling Config: limit={dcn_limit}, burst={dcn_burst}, latency={latency}") + print("==================================================") + + if apply_throttling: + config = SimpleNamespace( + dcn_bandwidth_limit=dcn_limit, + dcn_bandwidth_burst=dcn_burst, + dcn_bandwidth_latency=latency, + dcn_bandwidth_interface=interface, + ) + train_utils.maybe_apply_dcn_throttling(config) + else: + config = None + + try: + # Sync before starting benchmark + jax.block_until_ready(jax.device_put(0.0) + 1.0) + + if jax.process_index() == 0: + print(f"Starting benchmark for shape: {shape} ({matrix_size_gbyte * 1000:.1f} MB)") + + # Run time test + time_ms = simple_timeit(jitted_op, sharded_matrix, task=f"psum_dcn_{shape}") + + # Calculate Bandwidth + achieved_bandwidth_gbyte_s = matrix_size_gbyte * (dcn_size - 1) * 2 / dcn_size / dcn_size / (time_ms / 1e3) + achieved_bandwidth_gbps = achieved_bandwidth_gbyte_s * 8.0 + + if jax.process_index() == 0: + print(f"Results for {name}:") + print(f" Avg Latency: {time_ms:.2f} ms") + print( + f" Achieved DCN Bandwidth: {achieved_bandwidth_gbyte_s:.3f} GB/s ({achieved_bandwidth_gbps:.2f} Gbps) per slice" + ) + finally: + if apply_throttling and config: + if jax.process_index() == 0: + print(f"Cleaning up throttling for {name}...") + train_utils.cleanup_dcn_throttling(config) + + # Sync after cleanup + jax.block_until_ready(jax.device_put(0.0) + 1.0) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="DCN Bandwidth Benchmark") + parser.add_argument( + "--case", choices=["none", "100g", "50g"], help="Predefined throttling case (optional, runs all if omitted)" + ) + parser.add_argument("--limit", help="Custom DCN bandwidth limit (e.g. 100gbit)") + parser.add_argument("--burst", help="Custom DCN bandwidth burst (e.g. 600mb)") + parser.add_argument("--latency", default="50ms", help="DCN bandwidth latency (default: 50ms)") + parsed_args = parser.parse_args() + + run_dcn_benchmark(case=parsed_args.case, limit=parsed_args.limit, burst=parsed_args.burst, latency=parsed_args.latency)