Skip to content

Commit d3aeec3

Browse files
authored
5052 Fix bundle.load API (#5053)
Fixes #5052 . ### Description This PR updated the `bundle.load` API to fix the wrong path issue. ### Status **Ready** ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent 857e92b commit d3aeec3

File tree

2 files changed

+50
-25
lines changed

2 files changed

+50
-25
lines changed

monai/bundle/scripts.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def _process_bundle_dir(bundle_dir: Optional[PathLike] = None):
153153

154154
def download(
155155
name: Optional[str] = None,
156+
version: Optional[str] = None,
156157
bundle_dir: Optional[PathLike] = None,
157158
source: str = "github",
158159
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
@@ -170,11 +171,14 @@ def download(
170171
171172
.. code-block:: bash
172173
174+
# Execute this module as a CLI entry, and download bundle from the model-zoo repo:
175+
python -m monai.bundle download --name <bundle_name> --version "0.1.0" --bundle_dir "./"
176+
173177
# Execute this module as a CLI entry, and download bundle:
174-
python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag"
178+
python -m monai.bundle download --name <bundle_name> --source "github" --repo "repo_owner/repo_name/release_tag"
175179
176180
# Execute this module as a CLI entry, and download bundle via URL:
177-
python -m monai.bundle download --name "bundle_name" --url <url>
181+
python -m monai.bundle download --name <bundle_name> --url <url>
178182
179183
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
180184
# Other args still can override the default args at runtime.
@@ -185,6 +189,9 @@ def download(
185189
186190
Args:
187191
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
192+
for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
193+
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
194+
version: version name of the target bundle to download, like: "0.1.0".
188195
bundle_dir: target directory to store the downloaded data.
189196
Default is `bundle` subfolder under `torch.hub.get_dir()`.
190197
source: storage location name. This argument is used when `url` is `None`.
@@ -200,19 +207,28 @@ def download(
200207
201208
"""
202209
_args = _update_args(
203-
args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress
210+
args=args_file,
211+
name=name,
212+
version=version,
213+
bundle_dir=bundle_dir,
214+
source=source,
215+
repo=repo,
216+
url=url,
217+
progress=progress,
204218
)
205219

206220
_log_input_summary(tag="download", args=_args)
207-
source_, repo_, progress_, name_, bundle_dir_, url_ = _pop_args(
208-
_args, "source", "repo", "progress", name=None, bundle_dir=None, url=None
221+
source_, repo_, progress_, name_, version_, bundle_dir_, url_ = _pop_args(
222+
_args, "source", "repo", "progress", name=None, version=None, bundle_dir=None, url=None
209223
)
210224

211225
bundle_dir_ = _process_bundle_dir(bundle_dir_)
226+
if name_ is not None and version_ is not None:
227+
name_ = "_v".join([name_, version_])
212228

213229
if url_ is not None:
214-
if name is not None:
215-
filepath = bundle_dir_ / f"{name}.zip"
230+
if name_ is not None:
231+
filepath = bundle_dir_ / f"{name_}.zip"
216232
else:
217233
filepath = bundle_dir_ / f"{_basename(url_)}"
218234
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
@@ -229,6 +245,7 @@ def download(
229245

230246
def load(
231247
name: str,
248+
version: Optional[str] = None,
232249
model_file: Optional[str] = None,
233250
load_ts_module: bool = False,
234251
bundle_dir: Optional[PathLike] = None,
@@ -245,7 +262,9 @@ def load(
245262
Load model weights or TorchScript module of a bundle.
246263
247264
Args:
248-
name: bundle name.
265+
name: bundle name, for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
266+
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
267+
version: version name of the target bundle to download, like: "0.1.0".
249268
model_file: the relative path of the model weights or TorchScript module within bundle.
250269
If `None`, "models/model.pt" or "models/model.ts" will be used.
251270
load_ts_module: a flag to specify if loading the TorchScript module.
@@ -280,7 +299,7 @@ def load(
280299
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
281300
full_path = os.path.join(bundle_dir_, name, model_file)
282301
if not os.path.exists(full_path):
283-
download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)
302+
download(name=name, version=version, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)
284303

285304
if device is None:
286305
device = "cuda:0" if is_available() else "cpu"

tests/test_bundle_download.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,31 +28,31 @@
2828
skip_if_windows,
2929
)
3030

31-
TEST_CASE_1 = [
32-
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
33-
"test_bundle",
34-
"Project-MONAI/MONAI-extra-test-data/0.8.1",
35-
"a131d39a0af717af32d19e565b434928",
36-
]
31+
TEST_CASE_1 = ["test_bundle", None]
32+
33+
TEST_CASE_2 = ["test_bundle_v0.1.1", None]
3734

38-
TEST_CASE_2 = [
35+
TEST_CASE_3 = ["test_bundle", "0.1.1"]
36+
37+
TEST_CASE_4 = [
3938
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
4039
"test_bundle",
4140
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle.zip",
4241
"a131d39a0af717af32d19e565b434928",
4342
]
4443

45-
TEST_CASE_3 = [
44+
TEST_CASE_5 = [
4645
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
4746
"test_bundle",
4847
"Project-MONAI/MONAI-extra-test-data/0.8.1",
4948
"cuda" if torch.cuda.is_available() else "cpu",
5049
"model.pt",
5150
]
5251

53-
TEST_CASE_4 = [
52+
TEST_CASE_6 = [
5453
["test_output.pt", "test_input.pt"],
5554
"test_bundle",
55+
"0.1.1",
5656
"Project-MONAI/MONAI-extra-test-data/0.8.1",
5757
"cuda" if torch.cuda.is_available() else "cpu",
5858
"model.ts",
@@ -61,22 +61,27 @@
6161

6262
@skip_if_windows
6363
class TestDownload(unittest.TestCase):
64-
@parameterized.expand([TEST_CASE_1])
64+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
6565
@skip_if_quick
66-
def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val):
66+
def test_download_bundle(self, bundle_name, version):
67+
bundle_files = ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"]
68+
repo = "Project-MONAI/MONAI-extra-test-data/0.8.1"
69+
hash_val = "a131d39a0af717af32d19e565b434928"
6770
with skip_if_downloading_fails():
6871
# download a whole bundle from github releases
6972
with tempfile.TemporaryDirectory() as tempdir:
7073
cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"]
7174
cmd += ["--bundle_dir", tempdir, "--repo", repo, "--progress", "False"]
75+
if version is not None:
76+
cmd += ["--version", version]
7277
command_line_tests(cmd)
7378
for file in bundle_files:
74-
file_path = os.path.join(tempdir, bundle_name, file)
79+
file_path = os.path.join(tempdir, "test_bundle", file)
7580
self.assertTrue(os.path.exists(file_path))
7681
if file == "network.json":
7782
self.assertTrue(check_hash(filepath=file_path, val=hash_val))
7883

79-
@parameterized.expand([TEST_CASE_2])
84+
@parameterized.expand([TEST_CASE_4])
8085
@skip_if_quick
8186
def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):
8287
with skip_if_downloading_fails():
@@ -97,7 +102,7 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):
97102

98103

99104
class TestLoad(unittest.TestCase):
100-
@parameterized.expand([TEST_CASE_3])
105+
@parameterized.expand([TEST_CASE_5])
101106
@skip_if_quick
102107
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
103108
with skip_if_downloading_fails():
@@ -144,16 +149,17 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
144149
output_2 = model_2.forward(input_tensor)
145150
torch.testing.assert_allclose(output_2, expected_output)
146151

147-
@parameterized.expand([TEST_CASE_4])
152+
@parameterized.expand([TEST_CASE_6])
148153
@skip_if_quick
149154
@SkipIfBeforePyTorchVersion((1, 7, 1))
150-
def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_file):
155+
def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, model_file):
151156
with skip_if_downloading_fails():
152157
# load ts module
153158
with tempfile.TemporaryDirectory() as tempdir:
154159
# load ts module
155160
model_ts, metadata, extra_file_dict = load(
156161
name=bundle_name,
162+
version=version,
157163
model_file=model_file,
158164
load_ts_module=True,
159165
bundle_dir=tempdir,

0 commit comments

Comments
 (0)