Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ def main_parser() -> argparse.ArgumentParser:
action="store_true",
help="(Supported backend: PyTorch) Force load from ckpt, other missing tensors will init from scratch",
)
parser_train.add_argument(
"--allow-ref",
action="store_true",
help="Allow loading external JSON/YAML snippets through `$ref`. Disabled by default for security.",
)

# * freeze script ******************************************************************
parser_frz = subparsers.add_parser(
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,16 @@ def train(
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
allow_ref: bool = False,
) -> None:
"""Train a model with Paddle backend.

Parameters
----------
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
"""
log.info("Configuration path: %s", input_file)
if LOCAL_RANK == 0:
SummaryPrinter()()
Expand Down Expand Up @@ -292,7 +301,7 @@ def train(

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
config = normalize(config, multi_task=multi_task, allow_ref=allow_ref)

# do neighbor stat
min_nbor_dist = None
Expand Down Expand Up @@ -600,6 +609,7 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
allow_ref=FLAGS.allow_ref,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,16 @@ def train(
use_pretrain_script: bool = False,
force_load: bool = False,
output: str = "out.json",
allow_ref: bool = False,
) -> None:
"""Train a model with PyTorch backend.

Parameters
----------
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
"""
log.info("Configuration path: %s", input_file)
env.CUSTOM_OP_USE_JIT = True
if LOCAL_RANK == 0:
Expand Down Expand Up @@ -325,7 +334,7 @@ def train(

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config, multi_task=multi_task)
config = normalize(config, multi_task=multi_task, allow_ref=allow_ref)

# do neighbor stat
min_nbor_dist = None
Expand Down Expand Up @@ -578,6 +587,7 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
use_pretrain_script=FLAGS.use_pretrain_script,
force_load=FLAGS.force_load,
output=FLAGS.output,
allow_ref=FLAGS.allow_ref,
)
elif FLAGS.command == "freeze":
if Path(FLAGS.checkpoint_folder).is_dir():
Expand Down
6 changes: 5 additions & 1 deletion deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def train(
skip_neighbor_stat: bool = False,
finetune: str | None = None,
use_pretrain_script: bool = False,
allow_ref: bool = False,
**kwargs: Any,
) -> None:
"""Run DeePMD model training.
Expand Down Expand Up @@ -101,6 +102,9 @@ def train(
use_pretrain_script : bool
Whether to use model script in pretrained model when doing init-model or init-frz-model.
Note that this option is true and unchangeable for fine-tuning.
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``
in the training input. Disabled by default for security.
**kwargs
additional arguments

Expand Down Expand Up @@ -168,7 +172,7 @@ def train(

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
jdata = normalize(jdata, allow_ref=allow_ref)

if not is_compress and not skip_neighbor_stat:
jdata = update_sel(jdata)
Expand Down
27 changes: 24 additions & 3 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3821,10 +3821,31 @@ def gen_json_schema(multi_task: bool = False) -> str:
return json.dumps(generate_json_schema(arg))


def normalize(data: dict[str, Any], multi_task: bool = False) -> dict[str, Any]:
def normalize(
data: dict[str, Any],
multi_task: bool = False,
allow_ref: bool = False,
) -> dict[str, Any]:
"""Normalize and validate DeePMD input config.

Parameters
----------
data : dict[str, Any]
Input training configuration.
multi_task : bool, default=False
Whether to use multi-task argument schema.
allow_ref : bool, default=False
Whether to allow loading external JSON/YAML snippets via ``$ref``.
Disabled by default for security.

Returns
-------
dict[str, Any]
Normalized and validated configuration.
"""
base = Argument("base", dict, gen_args(multi_task=multi_task))
data = base.normalize_value(data, trim_pattern="_*")
base.check_value(data, strict=True)
data = base.normalize_value(data, trim_pattern="_*", allow_ref=allow_ref)
base.check_value(data, strict=True, allow_ref=allow_ref)

return data

Expand Down
2 changes: 2 additions & 0 deletions doc/train/training-advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ An explanation will be provided

**`--skip-neighbor-stat`** skips the calculation of neighbor statistics, which is useful if one is concerned about performance. Some features will be disabled when this option is set.

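**`--allow-ref`** enables loading external JSON/YAML snippets via `$ref` when the training input is normalized and validated. This option is disabled by default for security, because a `$ref` pulls in content from outside the input file; enable it only for input files you trust.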

To maximize the performance, one should follow [FAQ: How to control the parallelism of a job](../troubleshooting/howtoset_num_nodes.md) to control the number of threads.
See [Runtime environment variables](../env.md) for all runtime environment variables.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
'numpy>=1.21',
'scipy',
'pyyaml',
'dargs >= 0.4.7',
'dargs >= 0.5.0',
'typing_extensions>=4.0.0',
'importlib_metadata>=1.4; python_version < "3.8"',
'h5py',
Expand Down
8 changes: 8 additions & 0 deletions source/tests/common/test_argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,14 @@ def test_parser_train_wrong_subcommand(self) -> None:
with self.assertRaises(SystemExit):
self.run_test(command="train", mapping=ARGS)

def test_parser_train_allow_ref(self) -> None:
"""Test train --allow-ref option."""
args = parse_args(["train", "INFILE", "--allow-ref"])
self.assertTrue(args.allow_ref)

args_default = parse_args(["train", "INFILE"])
self.assertFalse(args_default.allow_ref)

def test_parser_freeze(self) -> None:
"""Test freeze subparser."""
ARGS = {
Expand Down