
Commit 69b8811

Add compilation metrics, parse step time in e2e, add name to train partial etc (#166)
* Add compilation metrics, script to assert metrics, add conditional import of huggingface_hub, and update huggingface_hub
* Formatting
* Fix pylint issues

Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent 0d9afba commit 69b8811

20 files changed: +219 -35 lines

end_to_end/tpu/eval_assert.py

Lines changed: 77 additions & 5 deletions

@@ -14,6 +14,15 @@
 limitations under the License.
 """
 
+"""
+Example to run
+python end_to_end/tpu/eval_assert.py avg_tflops metrics.txt 100
+python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
+python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
+"""
+
+
+
 # pylint: skip-file
 """Reads and asserts over target values"""
 from absl import app
@@ -34,26 +43,89 @@ def get_last_n_data(metrics_file, target, n=10):
   return last_n_data
 
 
-def test_final_loss(metrics_file, target_loss):
+def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
   target_loss = float(target_loss)
+  num_samples = int(num_samples_str)
   with open(metrics_file, "r", encoding="utf8") as _:
-    use_last_n_data = 10
-    last_n_data = get_last_n_data(metrics_file, "learning/loss", use_last_n_data)
+    last_n_data = get_last_n_data(metrics_file, "learning/loss", num_samples)
     avg_last_n_data = sum(last_n_data) / len(last_n_data)
     print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
     print(f"Target loss is {target_loss}")
     assert avg_last_n_data < target_loss
     print("Final loss test passed.")
 
 
+def test_avg_step_time(metrics_file, max_avg_step_time_str, num_samples_str="10"):
+  """Tests if the average of the last N step times is below a maximum threshold."""
+  max_avg_step_time = float(max_avg_step_time_str)
+  num_samples = int(num_samples_str)
+  metric_key = "perf/step_time_seconds"
+  last_n_step_times = get_last_n_data(metrics_file, metric_key, num_samples)
+
+  if not last_n_step_times:
+    raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
+
+  avg_last_n_step_time = sum(last_n_step_times) / len(last_n_step_times)
+
+  print(f"Found {len(last_n_step_times)} data points for '{metric_key}'.")
+  print(f"Mean of last {len(last_n_step_times)} step times is {avg_last_n_step_time:.4f} s")
+
+  assert (
+      avg_last_n_step_time < max_avg_step_time
+  ), f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
+  print("Average step time test passed.")
+
+
+def test_avg_tflops(metrics_file, min_avg_tflops_str, num_samples_str="10"):
+  """Tests if the average of the last N TFLOPs/sec values is above a minimum threshold."""
+  min_avg_tflops = float(min_avg_tflops_str)
+  num_samples = int(num_samples_str)
+  metric_key = "perf/per_device_tflops_per_sec"
+
+  last_n_tflops = get_last_n_data(metrics_file, metric_key, num_samples)
+
+  if not last_n_tflops:
+    raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")
+
+  avg_last_n_tflops = sum(last_n_tflops) / len(last_n_tflops)
+
+  print(f"Found {len(last_n_tflops)} data points for '{metric_key}'.")
+  print(f"Mean of last {len(last_n_tflops)} steps TFLOPs/sec is {avg_last_n_tflops:.2f}")
+
+  assert (
+      avg_last_n_tflops > min_avg_tflops
+  ), f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
+  print("Average TFLOPs/sec test passed.")
+
+
 def main(argv: Sequence[str]) -> None:
+  if len(argv) < 2:
+    print("Usage: python script.py <test_scenario> [test_vars...]")
+    print("Available scenarios: final_loss, avg_step_time, avg_tflops")
+    raise ValueError("Test scenario not specified.")
 
   _, test_scenario, *test_vars = argv
 
   if test_scenario == "final_loss":
-    test_final_loss(*test_vars)
+    if len(test_vars) < 2:
+      raise ValueError("Usage: final_loss <metrics_file> <target_loss> [num_samples]")
+    metrics_file, target_loss, *num_samples_opt = test_vars
+    num_samples = num_samples_opt[0] if num_samples_opt else "10"
+    test_final_loss(metrics_file, target_loss, num_samples)
+  elif test_scenario == "avg_step_time":
+    if len(test_vars) < 2:
+      raise ValueError("Usage: avg_step_time <metrics_file> <max_avg_step_time> [num_samples]")
+    metrics_file, max_avg_step_time, *num_samples_opt = test_vars
+    num_samples = num_samples_opt[0] if num_samples_opt else "10"
+    test_avg_step_time(metrics_file, max_avg_step_time, num_samples)
+  elif test_scenario == "avg_tflops":
+    if len(test_vars) < 2:
+      raise ValueError("Usage: avg_tflops <metrics_file> <min_avg_tflops> [num_samples]")
+    metrics_file, min_avg_tflops, *num_samples_opt = test_vars
+    num_samples = num_samples_opt[0] if num_samples_opt else "10"
+    test_avg_tflops(metrics_file, min_avg_tflops, num_samples)
   else:
-    raise ValueError(f"Unrecognized test_scenario {test_scenario}")
+    raise ValueError(f"Unrecognized test_scenario '{test_scenario}'. Available: final_loss, avg_step_time, avg_tflops")
 
 
 if __name__ == "__main__":
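The assertions above call get_last_n_data, which is defined earlier in eval_assert.py and is not shown in this diff. As a rough sketch of the idea only (not the actual implementation), the helper below pulls the last N values of a named metric, assuming a JSON-lines metrics file where each line carries a "metrics" dict keyed by names such as "learning/loss" or "perf/step_time_seconds"; the real file format and parsing in maxdiffusion may differ.

import json

def read_last_n_metric_values(metrics_file, target, n=10):
  """Hypothetical stand-in for get_last_n_data: returns the last n values of `target`.

  Assumes each line of `metrics_file` is a JSON object such as
  {"step": 3, "metrics": {"learning/loss": 0.12, "perf/step_time_seconds": 0.5}};
  the actual format written by maxdiffusion may differ.
  """
  values = []
  with open(metrics_file, "r", encoding="utf8") as f:
    for line in f:
      line = line.strip()
      if not line:
        continue
      record = json.loads(line)
      metrics = record.get("metrics", record)  # tolerate flat records too
      if target in metrics:
        values.append(float(metrics[target]))
  return values[-n:]

With a file in that assumed format, read_last_n_metric_values("metrics.txt", "perf/step_time_seconds", 10) would feed the same average-then-threshold check that test_avg_step_time performs.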

end_to_end/tpu/test_sdxl_training_loss.sh

File mode changed: 100644 → 100755

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ git+https://github.com/mlperf/logging.git
 opencv-python-headless==4.10.0.84
 orbax-checkpoint==0.10.3
 tokenizers==0.21.0
-huggingface_hub==0.24.7
+huggingface_hub==0.30.2
 transformers==4.48.1
 einops==0.8.0
 sentencepiece

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ ftfy
 git+https://github.com/mlperf/logging.git
 google-cloud-storage==2.17.0
 grain-nightly==0.0.10
-huggingface_hub==0.24.7
+huggingface_hub==0.30.2
 jax>=0.4.30
 jaxlib>=0.4.30
 Jinja2

setup.py

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub==0.24.7",
+    "huggingface-hub==0.30.0",
     "requests-mock==1.10.0",
     "importlib_metadata",
     "invisible-watermark>=0.2.0",

src/maxdiffusion/configs/base14.yml

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: False
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 4 additions & 0 deletions

@@ -18,6 +18,10 @@ run_name: ''
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 gcs_metrics: False
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 write_metrics: True
 gcs_metrics: False
+
+timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
+write_timing_metrics: True
+
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
 log_period: 100
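Each base config now gains timing_metrics_file and write_timing_metrics alongside the existing metrics_file/write_metrics keys, giving training runs a separate local sink for one-off timings such as state creation and compilation. The actual writer lives in the training code and is not part of this commit page; the snippet below is only a hypothetical illustration of how such a flag pair could be consumed, where record_timing and the JSON-lines schema are assumptions, not the real maxdiffusion code path.

import json
import time
from contextlib import contextmanager

@contextmanager
def record_timing(config, name):
  # Hypothetical helper: appends one JSON line per timed region to
  # config.timing_metrics_file when write_timing_metrics is enabled.
  # The schema and code path in maxdiffusion itself may differ.
  start = time.perf_counter()
  try:
    yield
  finally:
    elapsed = time.perf_counter() - start
    if getattr(config, "write_timing_metrics", False) and getattr(config, "timing_metrics_file", ""):
      with open(config.timing_metrics_file, "a", encoding="utf8") as f:
        f.write(json.dumps({"metric": name, "seconds": elapsed}) + "\n")

# Hypothetical usage: time model compilation during training setup.
# with record_timing(config, "compile_train_step"):
#   compiled_step = jax.jit(train_step).lower(state, batch, rng).compile()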
