Skip to content

Commit b08ea95

Browse files
committed
Recipe Modularization, DAG Integration, and BigQuery Integration
1 parent 9204d6b commit b08ea95

12 files changed

+317
-56
lines changed

benchmarks/globals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os.path
1818

1919
# This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
20-
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "MaxText")
20+
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")
2121

2222
# This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
2323
MAXTEXT_REPO_ROOT = os.environ.get(

benchmarks/maxtext_xpk_runner.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,8 @@ def generate_xpk_workload_cmd(
583583
cluster_config: XpkClusterConfig,
584584
wl_config: WorkloadConfig,
585585
workload_name=None,
586+
user=os.environ["USER"],
587+
temp_key=None,
586588
exp_name=None,
587589
):
588590
"""Generates a command to run a maxtext model on XPK."""
@@ -592,15 +594,19 @@ def generate_xpk_workload_cmd(
592594

593595
time.localtime()
594596
length_of_random_str = 3
595-
temp_post_fix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str))
597+
# Allow DAG to resolve workload name for cleanup, preventing reliance on random IDs
598+
if temp_key is not None:
599+
temp_post_fix = temp_key
600+
else:
601+
temp_post_fix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str))
596602

597603
truncate_model_name = 10
598604
truncate_prefix = 3
599605
post_fix = f"-{wl_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
600-
common_prefix = os.environ["USER"]
606+
common_prefix = user
601607
pw_prefix = "pw-"
602608

603-
if workload_name is None: # Generate name if not provided
609+
if workload_name is None:
604610
if is_pathways_enabled:
605611
post_fix = f"-{wl_config.num_slices}-{temp_post_fix}"
606612
name = f"{pw_prefix}{wl_config.model.model_name.replace('_', '-')[:truncate_model_name - len(pw_prefix)]}"

benchmarks/recipes/args_helper.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
be used to clean up existing XPK workloads before starting a new run.
2121
"""
2222

23-
import argparse
2423
import os
2524

2625
from benchmarks.xpk_configs import XpkClusterConfig
@@ -66,41 +65,19 @@ def handle_delete_specific_workload(cluster_config: XpkClusterConfig, workload_n
6665
os.system(f"yes | {delete_command}")
6766

6867

69-
def handle_cmd_args(cluster_config: XpkClusterConfig, *actions: str, **kwargs) -> bool:
68+
def handle_cmd_args(cluster_config: XpkClusterConfig, is_delete: bool, user: str, **kwargs) -> bool:
7069
"""Executes the requested workload actions and reports whether the caller should continue.
7170
7271
Args:
7372
cluster_config: Contains Cluster configuration information that's helpful
7473
for running the actions.
75-
*actions: Variable number of string arguments representing the actions to
76-
be performed.
74+
is_delete: A boolean indicating whether the delete action should be
75+
performed.
7776
**kwargs: Optional keyword arguments to be passed to action handlers.
78-
79-
Raises:
80-
ValueError: If an unsupported action is provided or if unknown arguments are
81-
passed.
8277
"""
83-
84-
parser = argparse.ArgumentParser()
85-
86-
if DELETE in actions:
87-
parser.add_argument(
88-
"--delete",
89-
action="store_true",
90-
help="Delete workloads starting with the user's first five characters.",
91-
)
92-
93-
known_args, unknown_args = parser.parse_known_args()
94-
95-
if unknown_args:
96-
raise ValueError(f"Unrecognized arguments: {unknown_args}")
97-
98-
# Get user
99-
user = os.environ["USER"]
100-
10178
# Handle actions
10279
should_continue = True
103-
if DELETE in actions and known_args.delete:
80+
if is_delete:
10481
_handle_delete(cluster_config, user, **kwargs)
10582
should_continue = False
10683

benchmarks/recipes/mcjax_long_running_recipe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import benchmarks.maxtext_trillium_model_configs as model_configs
2828
import benchmarks.maxtext_xpk_runner as mxr
2929
from benchmarks.xpk_configs import XpkClusterConfig
30+
from . import user_configs
3031

3132
# Cluster Params
3233
CLUSTER = "v6e-256-cluster"
@@ -57,7 +58,7 @@ def main() -> None:
5758
)
5859

5960
# Handle command line arguments using args_helper
60-
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=XPK_PATH)
61+
should_continue = helper.handle_cmd_args(cluster_config, user_configs.USER_CONFIG.delete, user_configs.USER_CONFIG.user)
6162

6263
if not should_continue:
6364
return

benchmarks/recipes/parser_utils.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
This module provides utility functions for custom argument parsing
17+
and defines a comprehensive set of command-line arguments for configuring a machine learning workload.
18+
"""
19+
20+
import argparse
21+
22+
23+
def parse_int_list(arg):
  """Converts a comma-separated string such as "1,2,3" into a list of ints.

  Args:
    arg: The raw string; every comma-delimited token must be accepted by
      `int()`.

  Returns:
    A list with one integer per comma-delimited token, in input order.
  """
  tokens = arg.split(",")
  return list(map(int, tokens))
26+
27+
28+
def parse_str_list(arg):
  """Parses a string with comma-separated values into a list of strings.

  Surrounding whitespace is stripped from every item. Spaces alone do not
  separate items; only commas do.

  Args:
    arg: The raw comma-separated string, e.g. "pathways, mcjax".

  Returns:
    A list with one stripped string per comma-delimited token, in input order.
  """
  return [s.strip() for s in arg.split(",")]
31+
32+
33+
def str2bool(v):
  """Parses a string representation of a boolean value into a Python boolean.

  Args:
    v: Either an actual bool (returned unchanged) or the string "true" /
      "false" in any letter case.

  Returns:
    The corresponding Python boolean.

  Raises:
    argparse.ArgumentTypeError: If `v` is not a bool and is not exactly
      "true" or "false" (case-insensitive).
  """
  if isinstance(v, bool):
    return v
  # Compare for equality. `x in ("true")` is a substring test, because
  # ("true") is a plain string rather than a tuple, so values such as "",
  # "t" or "ru" would wrongly be accepted as True.
  lowered = v.lower()
  if lowered == "true":
    return True
  if lowered == "false":
    return False
  raise argparse.ArgumentTypeError("Boolean value expected (e.g., True or False).")
43+
44+
45+
def add_arguments(parser: argparse.ArgumentParser):
  """Registers the shared recipe command-line arguments on `parser`.

  The arguments are grouped into GCP/cluster settings, container images and
  their flags, model/benchmark selection, BigQuery logging, and miscellaneous
  options. The parser is modified in place; nothing is returned.

  Args:
    parser: parser to add shared arguments to.
  """
  # GCP Configuration
  parser.add_argument("--user", type=str, default="user_name", help="GCP user name.")
  parser.add_argument(
      "--cluster_name",
      type=str,
      default="test-v5e-32-cluster",
      help="Name of the TPU cluster.",
  )
  parser.add_argument("--project", type=str, default="cloud-tpu-cluster", help="GCP project ID.")
  parser.add_argument("--zone", type=str, default="us-south1-a", help="GCP zone for the cluster.")
  parser.add_argument(
      "--device_type",
      type=str,
      default="v5litepod-32",
      help="Type of TPU device (e.g., v5litepod-32).",
  )
  parser.add_argument(
      "--priority",
      type=str,
      choices=["low", "medium", "high", "very high"],
      default="medium",
      help="Priority of the job.",
  )

  # Image Configuration
  # --server_image defaults to the ".../pathways/server" image and
  # --proxy_image to the ".../pathways/proxy_server" image, so that each flag
  # name matches the image it selects.
  parser.add_argument(
      "--server_image",
      type=str,
      default="us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server",
      help="Docker image for the server.",
  )
  parser.add_argument(
      "--proxy_image",
      type=str,
      default="us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server",
      help="Docker image for the proxy server.",
  )
  parser.add_argument(
      "--runner",
      type=str,
      default="us-docker.pkg.dev/path/to/maxtext_runner",
      help="Docker image for the runner.",
  )
  parser.add_argument(
      "--colocated_python_image",
      type=str,
      default=None,
      help="Colocated Python image.",
  )
  parser.add_argument("--worker_flags", type=str, default="", help="Worker flags.")
  parser.add_argument("--proxy_flags", type=str, default="", help="Proxy flags.")
  parser.add_argument("--server_flags", type=str, default="", help="Server flags.")

  # Model Configuration
  parser.add_argument("--benchmark_steps", type=int, default=20, help="Number of benchmark steps.")
  # BooleanOptionalAction also generates the negative "--no-headless" flag.
  parser.add_argument(
      "--headless",
      action=argparse.BooleanOptionalAction,
      default=False,
      help="Run in headless mode.",
  )
  # The list-valued options below take one comma-separated string on the
  # command line.
  parser.add_argument(
      "--selected_model_framework",
      type=parse_str_list,
      default=["pathways"],
      help="List of model frameworks (e.g., pathways, mcjax)",
  )
  parser.add_argument(
      "--selected_model_names",
      type=parse_str_list,
      default=["llama3_1_8b_8192_v5e_256"],
      help="List of model names (e.g., llama3_1_8b_8192_v5e_256, llama2-7b-v5e-256)",
  )
  parser.add_argument(
      "--num_slices_list",
      type=parse_int_list,
      default=[2],
      help="List of number of slices.",
  )

  # BigQuery configuration
  # Unlike --headless, this flag takes an explicit value: "--bq_enable True".
  parser.add_argument(
      "--bq_enable",
      type=str2bool,
      default=False,
      help="Enable BigQuery logging. Must be True or False. Defaults to False.",
  )

  parser.add_argument(
      "--bq_db_project",
      type=str,
      default="",
      help="BigQuery project ID where the logging dataset resides.",
  )

  parser.add_argument(
      "--bq_db_dataset",
      type=str,
      default="",
      help="BigQuery dataset name where metrics will be written.",
  )

  # Other configurations
  parser.add_argument("--xpk_path", type=str, default="~/xpk", help="Path to xpk.")
  parser.add_argument("--delete", action="store_true", help="Delete the cluster workload")
  parser.add_argument("--max_restarts", type=int, default=0, help="Maximum number of restarts")
  # NOTE(review): presumably forwarded as `temp_key` to the XPK runner, where
  # it replaces the random workload-name suffix — confirm against the caller.
  parser.add_argument("--temp_key", type=str, default=None, help="Temporary placeholder code")

benchmarks/recipes/pw_elastic_training_recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def main() -> None:
4545
"""Main function to run the elastic training disruption test."""
4646
user_configs.USER_CONFIG.headless = False
4747
should_continue = helper.handle_cmd_args(
48-
user_configs.USER_CONFIG.cluster_config, helper.DELETE, xpk_path=user_configs.USER_CONFIG.xpk_path
48+
user_configs.USER_CONFIG.cluster_config, user_configs.USER_CONFIG.delete, user_configs.USER_CONFIG.user
4949
)
5050

5151
if not should_continue:

benchmarks/recipes/pw_headless_mode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323

2424
import benchmarks.recipes.args_helper as helper
2525
import maxtext_xpk_runner as mxr
26-
from recipes.user_configs import cluster_config, xpk_path, pathways_config, base_output_directory, headless_workload_name
26+
from recipes.user_configs import cluster_config, xpk_path, pathways_config, base_output_directory, headless_workload_name, delete, user
2727

2828

2929
def main() -> int:
3030
# Handle command line arguments using args_helper
31-
should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=xpk_path)
31+
should_continue = helper.handle_cmd_args(cluster_config, delete, user)
3232

3333
if not should_continue:
3434
return 0

0 commit comments

Comments
 (0)