Skip to content

Commit be4e1f5

Browse files
authored
Add bundle root directory to Python search directories automatically (#6910)
Fixes #6722 . ### Description Add scripts directory to Python search directories automatically in the `run` function in `ConfigWorkflow`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent a4e4894 commit be4e1f5

File tree

2 files changed

+87
-8
lines changed

2 files changed

+87
-8
lines changed

monai/bundle/workflows.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import os
15+
import sys
1516
import time
1617
import warnings
1718
from abc import ABC, abstractmethod
@@ -170,6 +171,7 @@ class ConfigWorkflow(BundleWorkflow):
170171
"""
171172
Specification for the config-based bundle workflow.
172173
Standardized the `initialize`, `run`, `finalize` behavior in a config-based training, evaluation, or inference.
174+
Before `run`, we add bundle root directory to Python search directories automatically.
173175
For more information: https://docs.monai.io/en/latest/mb_specification.html.
174176
175177
Args:
@@ -224,23 +226,23 @@ def __init__(
224226
super().__init__(workflow_type=workflow_type)
225227
if config_file is not None:
226228
_config_files = ensure_tuple(config_file)
227-
config_root_path = Path(_config_files[0]).parent
229+
self.config_root_path = Path(_config_files[0]).parent
228230
for _config_file in _config_files:
229231
_config_file = Path(_config_file)
230-
if _config_file.parent != config_root_path:
232+
if _config_file.parent != self.config_root_path:
231233
warnings.warn(
232-
f"Not all config files are in {config_root_path}. If logging_file and meta_file are"
233-
f"not specified, {config_root_path} will be used as the default config root directory."
234+
f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
235+
f"not specified, {self.config_root_path} will be used as the default config root directory."
234236
)
235237
if not _config_file.is_file():
236238
raise FileNotFoundError(f"Cannot find the config file: {_config_file}.")
237239
else:
238-
config_root_path = Path("configs")
240+
self.config_root_path = Path("configs")
239241

240-
logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file
242+
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
241243
if logging_file is not None:
242244
if not os.path.exists(logging_file):
243-
if logging_file == str(config_root_path / "logging.conf"):
245+
if logging_file == str(self.config_root_path / "logging.conf"):
244246
warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
245247
else:
246248
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
@@ -250,7 +252,7 @@ def __init__(
250252

251253
self.parser = ConfigParser()
252254
self.parser.read_config(f=config_file)
253-
meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
255+
meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
254256
if isinstance(meta_file, str) and not os.path.exists(meta_file):
255257
raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
256258
else:
@@ -283,8 +285,13 @@ def initialize(self) -> Any:
283285
def run(self) -> Any:
284286
"""
285287
Run the bundle workflow, it can be a training, evaluation or inference.
288+
Before run, we add bundle root directory to Python search directories automatically.
286289
287290
"""
291+
_bundle_root_path = (
292+
self.config_root_path.parent if self.config_root_path.name == "configs" else self.config_root_path
293+
)
294+
sys.path.insert(1, str(_bundle_root_path))
288295
if self.run_id not in self.parser:
289296
raise ValueError(f"run ID '{self.run_id}' doesn't exist in the config file.")
290297
return self._run_expr(id=self.run_id)

tests/test_integration_bundle_run.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import os
1616
import shutil
17+
import subprocess
1718
import sys
1819
import tempfile
1920
import unittest
@@ -44,6 +45,14 @@ def run(self):
4445
return self.val
4546

4647

48+
class _Runnable43:
49+
def __init__(self, func):
50+
self.func = func
51+
52+
def run(self):
53+
self.func()
54+
55+
4756
class TestBundleRun(unittest.TestCase):
4857
def setUp(self):
4958
self.data_dir = tempfile.mkdtemp()
@@ -77,6 +86,69 @@ def test_tiny(self):
7786
with self.assertRaises(RuntimeError):
7887
# test wrong run_id="run"
7988
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
89+
with self.assertRaises(RuntimeError):
90+
# test missing meta file
91+
command_line_tests(cmd + ["run", "training", "--config_file", config_file])
92+
93+
def test_scripts_fold(self):
94+
# test scripts directory has been added to Python search directories automatically
95+
config_file = os.path.join(self.data_dir, "tiny_config.json")
96+
meta_file = os.path.join(self.data_dir, "tiny_meta.json")
97+
scripts_dir = os.path.join(self.data_dir, "scripts")
98+
script_file = os.path.join(scripts_dir, "test_scripts_fold.py")
99+
init_file = os.path.join(scripts_dir, "__init__.py")
100+
101+
with open(config_file, "w") as f:
102+
json.dump(
103+
{
104+
"imports": ["$import scripts"],
105+
"trainer": {
106+
"_target_": "tests.test_integration_bundle_run._Runnable43",
107+
"func": "$scripts.tiny_test",
108+
},
109+
# keep this test case to cover the "runner_id" arg
110+
"training": "$@trainer.run()",
111+
},
112+
f,
113+
)
114+
with open(meta_file, "w") as f:
115+
json.dump(
116+
{"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"},
117+
f,
118+
)
119+
120+
os.mkdir(scripts_dir)
121+
script_file_lines = ["def tiny_test():\n", " print('successfully added scripts fold!') \n"]
122+
init_file_line = "from .test_scripts_fold import tiny_test\n"
123+
with open(script_file, "w") as f:
124+
f.writelines(script_file_lines)
125+
f.close()
126+
with open(init_file, "w") as f:
127+
f.write(init_file_line)
128+
f.close()
129+
130+
cmd = ["coverage", "run", "-m", "monai.bundle"]
131+
# test both CLI entry "run" and "run_workflow"
132+
expected_condition = "successfully added scripts fold!"
133+
command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file]
134+
completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True)
135+
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
136+
print(output)
137+
138+
self.assertTrue(expected_condition in output)
139+
command_run_workflow = cmd + [
140+
"run_workflow",
141+
"--run_id",
142+
"training",
143+
"--config_file",
144+
config_file,
145+
"--meta_file",
146+
meta_file,
147+
]
148+
completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True)
149+
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
150+
print(output)
151+
self.assertTrue(expected_condition in output)
80152

81153
with self.assertRaises(RuntimeError):
82154
# test missing meta file

0 commit comments

Comments
 (0)