Skip to content

Commit 78d4f42

Browse files
authored
5648 track bundle config in MLFlow (#5666)
Fixes #5648. ### Description This PR adds support for tracking the executed bundle config in MLflow and also fixes #4057 (comment). ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent 0737a33 commit 78d4f42

File tree

10 files changed

+68
-32
lines changed

10 files changed

+68
-32
lines changed

monai/bundle/reference_resolver.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -221,21 +221,23 @@ def update_refs_pattern(cls, value: str, refs: Dict) -> str:
221221
result = cls.id_matcher.findall(value)
222222
value_is_expr = ConfigExpression.is_expression(value)
223223
for item in result:
224-
ref_id = item[len(cls.ref) :] # remove the ref prefix "@"
225-
if ref_id not in refs:
226-
msg = f"can not find expected ID '{ref_id}' in the references."
227-
if cls.allow_missing_reference:
228-
warnings.warn(msg)
229-
continue
230-
else:
231-
raise KeyError(msg)
232-
if value_is_expr:
233-
# replace with local code, `{"__local_refs": self.resolved_content}` will be added to
234-
# the `globals` argument of python `eval` in the `evaluate`
235-
value = value.replace(item, f"{cls._vars}['{ref_id}']")
236-
elif value == item:
237-
# the whole content is "@XXX", it will avoid the case that regular string contains "@"
238-
value = refs[ref_id]
224+
# only update reference when string starts with "$" or the whole content is "@XXX"
225+
if value_is_expr or value == item:
226+
ref_id = item[len(cls.ref) :] # remove the ref prefix "@"
227+
if ref_id not in refs:
228+
msg = f"can not find expected ID '{ref_id}' in the references."
229+
if cls.allow_missing_reference:
230+
warnings.warn(msg)
231+
continue
232+
else:
233+
raise KeyError(msg)
234+
if value_is_expr:
235+
# replace with local code, `{"__local_refs": self.resolved_content}` will be added to
236+
# the `globals` argument of python `eval` in the `evaluate`
237+
value = value.replace(item, f"{cls._vars}['{ref_id}']")
238+
elif value == item:
239+
# the whole content is "@XXX", it will avoid the case that regular string contains "@"
240+
value = refs[ref_id]
239241
return value
240242

241243
@classmethod

monai/bundle/scripts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import pprint
1616
import re
17+
import time
1718
import warnings
1819
from logging.config import fileConfig
1920
from pathlib import Path
@@ -489,6 +490,18 @@ def patch_bundle_tracking(parser: ConfigParser, settings: dict):
489490
handlers.append(v)
490491
elif k not in parser:
491492
parser[k] = v
493+
# save the executed config into file
494+
default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json"
495+
filepath = parser.get("execute_config", None)
496+
if filepath is None:
497+
if "output_dir" not in parser:
498+
# if no "output_dir" in the bundle config, default to "<bundle root>/eval"
499+
parser["output_dir"] = "$@bundle_root + '/eval'"
500+
# experiment management tools can refer to this config item to track the config info
501+
parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'"
502+
filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name)
503+
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
504+
parser.export_config_file(parser.get(), filepath)
492505

493506

494507
def run(

monai/bundle/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,14 @@
104104
DEFAULT_MLFLOW_SETTINGS = {
105105
"handlers_id": DEFAULT_HANDLERS_ID,
106106
"configs": {
107-
"tracking_uri": "$@output_dir + '/mlruns'",
107+
# if no "output_dir" in the bundle config, default to "<bundle root>/eval"
108+
"output_dir": "$@bundle_root + '/eval'",
109+
# use URI to support linux, mac and windows os
110+
"tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlruns'",
108111
"experiment_name": "monai_experiment",
109112
"run_name": None,
113+
# may fill it at runtime
114+
"execute_config": None,
110115
"is_not_rank0": (
111116
"$torch.distributed.is_available() \
112117
and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"
@@ -118,6 +123,7 @@
118123
"tracking_uri": "@tracking_uri",
119124
"experiment_name": "@experiment_name",
120125
"run_name": "@run_name",
126+
"artifacts": "@execute_config",
121127
"iteration_log": True,
122128
"epoch_log": True,
123129
"tag_name": "train_loss",
@@ -140,6 +146,7 @@
140146
"tracking_uri": "@tracking_uri",
141147
"experiment_name": "@experiment_name",
142148
"run_name": "@run_name",
149+
"artifacts": "@execute_config",
143150
"iteration_log": False,
144151
"close_on_complete": True,
145152
},

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
is_scalar_tensor,
7474
issequenceiterable,
7575
list_to_dict,
76+
path_to_uri,
7677
progress_bar,
7778
sample_slices,
7879
save_obj,

monai/utils/misc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"check_parent_dir",
5757
"save_obj",
5858
"label_union",
59+
"path_to_uri",
5960
]
6061

6162
_seed = None
@@ -584,3 +585,14 @@ def prob2class(x, sigmoid: bool = False, threshold: float = 0.5, **kwargs):
584585
threshold: threshold value to activate the sigmoid function.
585586
"""
586587
return torch.argmax(x, **kwargs) if not sigmoid else (x > threshold).int()
588+
589+
590+
def path_to_uri(path: PathLike) -> str:
    """
    Return the URI form of a file path.

    A relative path is first made absolute before it is converted,
    so the result is always a complete URI.

    Args:
        path: file path to convert, given as a string or a `Path` object.

    """
    absolute_path = Path(path).absolute()
    return absolute_path.as_uri()

tests/test_fl_monai_algo.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import shutil
1414
import tempfile
1515
import unittest
16-
from pathlib import Path
1716

1817
from parameterized import parameterized
1918

@@ -22,6 +21,7 @@
2221
from monai.fl.client.monai_algo import MonaiAlgo
2322
from monai.fl.utils.constants import ExtraItems
2423
from monai.fl.utils.exchange_object import ExchangeObject
24+
from monai.utils import path_to_uri
2525
from tests.utils import SkipIfNoModule
2626

2727
_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
@@ -151,12 +151,13 @@ def test_train(self, input_params):
151151
input_params["tracking"] = {
152152
"handlers_id": DEFAULT_HANDLERS_ID,
153153
"configs": {
154+
"execute_config": f"{data_dir}/config_executed.json",
154155
"trainer": {
155156
"_target_": "MLFlowHandler",
156-
"tracking_uri": Path(data_dir).as_uri() + "/mlflow_override",
157+
"tracking_uri": path_to_uri(data_dir) + "/mlflow_override",
157158
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
158159
"close_on_complete": True,
159-
}
160+
},
160161
},
161162
}
162163

@@ -177,6 +178,7 @@ def test_train(self, input_params):
177178
algo.train(data=data, extra={})
178179
algo.finalize()
179180
self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override"))
181+
self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json"))
180182
shutil.rmtree(data_dir)
181183

182184
@parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])

tests/test_handler_mlflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
import os
1414
import tempfile
1515
import unittest
16-
from pathlib import Path
1716

1817
import numpy as np
1918
from ignite.engine import Engine, Events
2019

2120
from monai.handlers import MLFlowHandler
21+
from monai.utils import path_to_uri
2222

2323

2424
class TestHandlerMLFlow(unittest.TestCase):
@@ -49,7 +49,7 @@ def _update_metric(engine):
4949
handler = MLFlowHandler(
5050
iteration_log=False,
5151
epoch_log=True,
52-
tracking_uri=Path(test_path).as_uri(),
52+
tracking_uri=path_to_uri(test_path),
5353
state_attributes=["test"],
5454
experiment_param=experiment_param,
5555
artifacts=[artifact_path],

tests/test_integration_bundle_run.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import sys
1616
import tempfile
1717
import unittest
18-
from pathlib import Path
18+
from glob import glob
1919

2020
import nibabel as nib
2121
import numpy as np
@@ -24,6 +24,7 @@
2424
from monai.bundle import ConfigParser
2525
from monai.bundle.utils import DEFAULT_HANDLERS_ID
2626
from monai.transforms import LoadImage
27+
from monai.utils import path_to_uri
2728
from tests.utils import command_line_tests
2829

2930
TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), (128, 128, 128)]
@@ -85,7 +86,7 @@ def test_shape(self, config_file, expected_shape):
8586
"no_epoch": True, # test override config in the settings file
8687
"evaluator": {
8788
"_target_": "MLFlowHandler",
88-
"tracking_uri": "$@output_dir + '/mlflow_override1'",
89+
"tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlflow_override1'",
8990
"iteration_log": "@no_epoch",
9091
},
9192
},
@@ -105,14 +106,12 @@ def test_shape(self, config_file, expected_shape):
105106
json.dump("Dataset", f)
106107

107108
if sys.platform == "win32":
108-
outdir = Path(tempdir).as_uri()
109109
override = "--network $@network_def.to(@device) --dataset#_target_ Dataset"
110110
else:
111-
outdir = tempdir
112111
override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}"
113112
# test with `monai.bundle` as CLI entry directly
114113
cmd = "-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg"
115-
cmd += f" {override} --no_epoch False --save_dir {tempdir} --output_dir {outdir}"
114+
cmd += f" {override} --no_epoch False --output_dir {tempdir}"
116115
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
117116
test_env = os.environ.copy()
118117
print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
@@ -121,14 +120,16 @@ def test_shape(self, config_file, expected_shape):
121120
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)
122121
self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override1"))
123122

124-
tracking_uri = outdir + "/mlflow_override2" # test override experiment management configs
123+
tracking_uri = path_to_uri(tempdir) + "/mlflow_override2" # test override experiment management configs
125124
# here test the script with `google fire` tool as CLI
126125
cmd = "-m fire monai.bundle.scripts run --runner_id evaluating --tracking mlflow --evaluator#amp False"
127-
cmd += f" --tracking_uri {tracking_uri} {override} --save_dir {tempdir} --output_dir {outdir}"
126+
cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir}"
128127
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
129128
command_line_tests(la)
130129
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)
131130
self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override2"))
131+
# test the saved execution configs
132+
# `assertTrue(x, 2)` treats `2` as the failure message and passes for any non-zero count;
# use `assertEqual` so the number of saved config files is actually compared against 2.
self.assertEqual(len(glob(f"{tempdir}/config_*.json")), 2)
132133

133134

134135
if __name__ == "__main__":

tests/testing_data/inference.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
{
22
"dataset_dir": "/workspace/data/Task09_Spleen",
3-
"save_dir": "need_override",
43
"output_dir": "need override",
54
"prediction_shape": "prediction shape:",
65
"import_glob": "$import glob",
@@ -89,7 +88,7 @@
8988
{
9089
"_target_": "SaveImaged",
9190
"keys": "pred",
92-
"output_dir": "@save_dir"
91+
"output_dir": "@output_dir"
9392
},
9493
{
9594
"_target_": "Lambdad",

tests/testing_data/inference.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
---
22
dataset_dir: "/workspace/data/Task09_Spleen"
3-
save_dir: "need_override"
43
output_dir: "need override"
54
prediction_shape: "prediction shape:"
65
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
@@ -65,7 +64,7 @@ postprocessing:
6564
argmax: true
6665
- _target_: SaveImaged
6766
keys: pred
68-
output_dir: "@save_dir"
67+
output_dir: "@output_dir"
6968
- _target_: Lambdad
7069
keys: pred
7170
func: "$lambda x: print(@prediction_shape + str(x.shape))"

0 commit comments

Comments
 (0)