Skip to content

Commit 5fd23b8

Browse files
KumoLiupre-commit-ci[bot]ericspod
authored
Improve ckpt_export (#6965)
Fixes #6953 ### Description ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 5b9d24e commit 5fd23b8

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

monai/bundle/scripts.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,9 +1214,13 @@ def ckpt_export(
12141214
12151215
Args:
12161216
net_id: ID name of the network component in the config, it must be `torch.nn.Module`.
1217+
Default to "network_def".
12171218
filepath: filepath to export, if filename has no extension it becomes `.ts`.
1219+
Default to "models/model.ts" under "os.getcwd()" if `bundle_root` is not specified.
12181220
ckpt_file: filepath of the model checkpoint to load.
1221+
Default to "models/model.pt" under "os.getcwd()" if `bundle_root` is not specified.
12191222
meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
1223+
Default to "configs/metadata.json" under "os.getcwd()" if `bundle_root` is not specified.
12201224
config_file: filepath of the config file to save in TorchScript model and extract network information,
12211225
the saved key in the TorchScript model is the config filename without extension, and the saved config
12221226
value is always serialized in JSON format no matter the original file format is JSON or YAML.
@@ -1250,9 +1254,10 @@ def ckpt_export(
12501254
)
12511255
_log_input_summary(tag="ckpt_export", args=_args)
12521256
(
1257+
config_file_,
12531258
filepath_,
12541259
ckpt_file_,
1255-
config_file_,
1260+
bundle_root_,
12561261
net_id_,
12571262
meta_file_,
12581263
key_in_ckpt_,
@@ -1261,10 +1266,11 @@ def ckpt_export(
12611266
converter_kwargs_,
12621267
) = _pop_args(
12631268
_args,
1264-
"filepath",
1265-
"ckpt_file",
12661269
"config_file",
1267-
net_id="",
1270+
filepath=None,
1271+
ckpt_file=None,
1272+
bundle_root=os.getcwd(),
1273+
net_id=None,
12681274
meta_file=None,
12691275
key_in_ckpt="",
12701276
use_trace=False,
@@ -1275,9 +1281,22 @@ def ckpt_export(
12751281
parser = ConfigParser()
12761282

12771283
parser.read_config(f=config_file_)
1278-
if meta_file_ is not None:
1284+
meta_file_ = os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ is None else meta_file_
1285+
filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ is None else filepath_
1286+
ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ is None else ckpt_file_
1287+
if not os.path.exists(ckpt_file_):
1288+
raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')
1289+
if os.path.exists(meta_file_):
12791290
parser.read_meta(f=meta_file_)
12801291

1292+
net_id_ = "network_def" if net_id_ is None else net_id_
1293+
try:
1294+
parser.get_parsed_content(net_id_)
1295+
except ValueError as e:
1296+
raise ValueError(
1297+
f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".'
1298+
) from e
1299+
12811300
# the rest key-values in the _args are to override config content
12821301
for k, v in _args.items():
12831302
parser[k] = v

tests/test_bundle_ckpt_export.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,29 @@ def test_export(self, key_in_ckpt, use_trace):
7575
self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
7676
self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))
7777

78+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
79+
def test_default_value(self, key_in_ckpt, use_trace):
80+
config_file = os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")
81+
with tempfile.TemporaryDirectory() as tempdir:
82+
def_args = {"meta_file": "will be replaced by `meta_file` arg"}
83+
def_args_file = os.path.join(tempdir, "def_args.yaml")
84+
ckpt_file = os.path.join(tempdir, "models/model.pt")
85+
ts_file = os.path.join(tempdir, "models/model.ts")
86+
87+
parser = ConfigParser()
88+
parser.export_config_file(config=def_args, filepath=def_args_file)
89+
parser.read_config(config_file)
90+
net = parser.get_parsed_content("network_def")
91+
save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file)
92+
93+
# check with default value
94+
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
95+
cmd += ["--config_file", config_file, "--bundle_root", tempdir]
96+
if use_trace == "True":
97+
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
98+
command_line_tests(cmd)
99+
self.assertTrue(os.path.exists(ts_file))
100+
78101

79102
if __name__ == "__main__":
80103
unittest.main()

0 commit comments

Comments
 (0)