Skip to content

Commit 187ecba

Browse files
committed
Add support for loading args from yaml file (and saving them with each experiment)
1 parent d3ba34e commit 187ecba

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
torch>=1.1.0
22
torchvision>=0.3.0
3+
pyyaml

train.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import time
44
import logging
5+
import yaml
56
from datetime import datetime
67

78
try:
@@ -26,6 +27,14 @@
2627

2728
torch.backends.cudnn.benchmark = True
2829

30+
31+
# The first arg parser parses out only the --config argument, this argument is used to
32+
# load a yaml file containing key-values that override the defaults for the main parser below
33+
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
34+
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
35+
help='YAML config file specifying default arguments')
36+
37+
2938
parser = argparse.ArgumentParser(description='Training')
3039
# Dataset / Model parameters
3140
parser.add_argument('data', metavar='DIR',
@@ -145,9 +154,27 @@
145154
parser.add_argument("--local_rank", default=0, type=int)
146155

147156

157+
def _parse_args():
158+
# Do we have a config file to parse?
159+
args_config, remaining = config_parser.parse_known_args()
160+
if args_config.config:
161+
with open(args_config.config, 'r') as f:
162+
cfg = yaml.safe_load(f)
163+
parser.set_defaults(**cfg)
164+
165+
# The main arg parser parses the rest of the args, the usual
166+
# defaults will have been overridden if config file specified.
167+
args = parser.parse_args(remaining)
168+
169+
# Cache the args as a text string to save them in the output dir later
170+
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
171+
return args, args_text
172+
173+
148174
def main():
149175
setup_default_logging()
150-
args = parser.parse_args()
176+
args, args_text = _parse_args()
177+
151178
args.prefetcher = not args.no_prefetcher
152179
args.distributed = False
153180
if 'WORLD_SIZE' in os.environ:
@@ -345,6 +372,8 @@ def main():
345372
output_dir = get_outdir(output_base, 'train', exp_name)
346373
decreasing = True if eval_metric == 'loss' else False
347374
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
375+
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
376+
f.write(args_text)
348377

349378
try:
350379
for epoch in range(start_epoch, num_epochs):

0 commit comments

Comments
 (0)