
Commit d29914d

6415-fix-mlflow-handler-run-bug (#6446)
Fixes #6415.

### Description

Fix an MLflow handler bug: when running a bundle with `MLFlowHandler` back to back without assigning a run name, the later run's info is recorded into the former run, even though the former run has already finished. This PR checks the status of existing runs and filters out the finished ones (see the sketch below).

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: binliu <binliu@nvidia.com>
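For illustration only, a minimal sketch of the filtering idea using the MLflow client API (not the handler source itself; the tracking setup is assumed, and `monai_experiment` is just the handler's default experiment name):

```python
from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient

client = MlflowClient()  # assumes a configured or default tracking URI
experiment = client.get_experiment_by_name("monai_experiment")  # assumes the experiment already exists

finished = RunStatus.to_string(RunStatus.FINISHED)  # "FINISHED"
runs = client.search_runs([experiment.experiment_id])
# a finished run must not receive any further metrics/params; keep only active runs
active_runs = [r for r in runs if r.info.status != finished]
```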
1 parent: 4219e0f

File tree: 3 files changed, +58 −12 lines

monai/handlers/mlflow_handler.py

Lines changed: 17 additions & 11 deletions
@@ -23,8 +23,10 @@
 from monai.utils import ensure_tuple, min_version, optional_import

 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
-mlflow, _ = optional_import("mlflow")
-mlflow.entities, _ = optional_import("mlflow.entities")
+mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
+mlflow.entities, _ = optional_import(
+    "mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
+)

 if TYPE_CHECKING:
     from ignite.engine import Engine

@@ -76,21 +78,23 @@ class MLFlowHandler:
             The default behavior is to track loss from output[0] as output is a decollated list
             and we replicated loss value for every item of the decollated list.
             `engine.state` and `output_transform` inherit from the ignite concept:
-            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://pytorch-ignite.ai/concepts/03-state/, explanation and usage example are in the tutorial:
             https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
         global_epoch_transform: a callable that is used to customize global epoch number.
             For example, in evaluation, the evaluator engine might want to track synced epoch number
             with the trainer engine.
         state_attributes: expected attributes from `engine.state`, if provided, will extract them
             when epoch completed.
         tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`.
-        experiment_name: name for an experiment, defaults to `default_experiment`.
-        run_name: name for run in an experiment.
-        experiment_param: a dict recording parameters which will not change through whole experiment,
+        experiment_name: the experiment name of MLflow, default to `'monai_experiment'`. An experiment can be
+            used to record several runs.
+        run_name: the run name in an experiment. A run can be used to record information about a workflow,
+            like the loss, metrics and so on.
+        experiment_param: a dict recording parameters which will not change through the whole workflow,
            like torch version, cuda version and so on.
-        artifacts: paths to images that need to be recorded after a whole run.
-        optimizer_param_names: parameters' name in optimizer that need to be record during running,
-            defaults to "lr".
+        artifacts: paths to images that need to be recorded after running the workflow.
+        optimizer_param_names: parameter names in the optimizer that need to be recorded during running the
+            workflow, default to `'lr'`.
         close_on_complete: whether to close the mlflow run in `complete` phase in workflow, default to False.

     For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html.

@@ -132,6 +136,7 @@ def __init__(
         self.artifacts = ensure_tuple(artifacts)
         self.optimizer_param_names = ensure_tuple(optimizer_param_names)
         self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None)
+        self.run_finish_status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)
         self.close_on_complete = close_on_complete
         self.experiment = None
         self.cur_run = None

@@ -191,6 +196,8 @@ def start(self, engine: Engine) -> None:
             run_name = f"run_{time.strftime('%Y%m%d_%H%M%S')}" if self.run_name is None else self.run_name
             runs = self.client.search_runs(self.experiment.experiment_id)
             runs = [r for r in runs if r.info.run_name == run_name or not self.run_name]
+            # runs marked as finish should not record info any more
+            runs = [r for r in runs if r.info.status != self.run_finish_status]
             if runs:
                 self.cur_run = self.client.get_run(runs[-1].info.run_id)  # pick latest active run
             else:

@@ -264,8 +271,7 @@ def close(self) -> None:

         """
         if self.cur_run:
-            status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)
-            self.client.set_terminated(self.cur_run.info.run_id, status)
+            self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status)
             self.cur_run = None

     def epoch_completed(self, engine: Engine) -> None:
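To connect the docstring changes above with typical usage, here is a minimal, hypothetical sketch of attaching the handler to an ignite engine (the tracking URI, dummy step function, and data are placeholders, not part of this commit):

```python
from ignite.engine import Engine
from monai.handlers import MLFlowHandler

# a trivial ignite engine whose step function just returns a dummy loss
trainer = Engine(lambda engine, batch: [0.0])

handler = MLFlowHandler(
    tracking_uri="file:///tmp/mlruns",   # assumed local tracking store
    experiment_name="monai_experiment",  # the default experiment name described above
    run_name=None,                       # None -> an auto-generated "run_<timestamp>" name
    close_on_complete=True,              # mark the run FINISHED when the workflow completes
)
handler.attach(trainer)
trainer.run([1, 2, 3], max_epochs=2)  # metrics/loss are logged under the run above
```

With `close_on_complete=True`, the run is terminated at the end of the workflow, so a subsequent workflow without an assigned run name starts a fresh run instead of appending to the finished one.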

tests/test_auto3dseg_bundlegen.py

Lines changed: 8 additions & 1 deletion
@@ -25,7 +25,13 @@
 from monai.bundle.config_parser import ConfigParser
 from monai.data import create_test_image_3d
 from monai.utils import set_determinism
-from tests.utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
+from tests.utils import (
+    SkipIfBeforePyTorchVersion,
+    get_testing_algo_template_path,
+    skip_if_downloading_fails,
+    skip_if_no_cuda,
+    skip_if_quick,
+)

 num_images_perfold = max(torch.cuda.device_count(), 4)
 num_images_per_batch = 2

@@ -97,6 +103,7 @@ def run_auto3dseg_before_bundlegen(test_path, work_dir):


 @skip_if_no_cuda
+@SkipIfBeforePyTorchVersion((1, 11, 1))
 @skip_if_quick
 class TestBundleGen(unittest.TestCase):
     def setUp(self) -> None:

tests/test_handler_mlflow.py

Lines changed: 33 additions & 0 deletions
@@ -68,6 +68,39 @@ def tearDown(self):
         if tmpdir and os.path.exists(tmpdir):
             shutil.rmtree(tmpdir)

+    def test_multi_run(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+            # set up the train function for engine
+            def _train_func(engine, batch):
+                return [batch + 1.0]
+
+            # create and run an engine several times to get several runs
+            create_engine_times = 3
+            for _ in range(create_engine_times):
+                engine = Engine(_train_func)
+
+                @engine.on(Events.EPOCH_COMPLETED)
+                def _update_metric(engine):
+                    current_metric = engine.state.metrics.get("acc", 0.1)
+                    engine.state.metrics["acc"] = current_metric + 0.1
+                    engine.state.test = current_metric
+
+                # set up testing handler
+                test_path = os.path.join(tempdir, "mlflow_test")
+                handler = MLFlowHandler(
+                    iteration_log=False,
+                    epoch_log=True,
+                    tracking_uri=path_to_uri(test_path),
+                    state_attributes=["test"],
+                    close_on_complete=True,
+                )
+                handler.attach(engine)
+                engine.run(range(3), max_epochs=2)
+                run_cnt = len(handler.client.search_runs(handler.experiment.experiment_id))
+                handler.close()
+            # the run count should equal to the times of creating engine
+            self.assertEqual(create_engine_times, run_cnt)
+
     def test_metrics_track(self):
         experiment_param = {"backbone": "efficientnet_b0"}
         with tempfile.TemporaryDirectory() as tempdir: