|
2 | 2 | import argparse |
3 | 3 | import time |
4 | 4 | import logging |
| 5 | +import yaml |
5 | 6 | from datetime import datetime |
6 | 7 |
|
7 | 8 | try: |
|
26 | 27 |
|
27 | 28 | torch.backends.cudnn.benchmark = True |
28 | 29 |
|
| 30 | + |
| 31 | +# The first arg parser parses out only the --config argument, this argument is used to |
| 32 | +# load a yaml file containing key-values that override the defaults for the main parser below |
| 33 | +config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) |
| 34 | +parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', |
| 35 | + help='YAML config file specifying default arguments') |
| 36 | + |
| 37 | + |
29 | 38 | parser = argparse.ArgumentParser(description='Training') |
30 | 39 | # Dataset / Model parameters |
31 | 40 | parser.add_argument('data', metavar='DIR', |
|
145 | 154 | parser.add_argument("--local_rank", default=0, type=int) |
146 | 155 |
|
147 | 156 |
|
| 157 | +def _parse_args(): |
| 158 | + # Do we have a config file to parse? |
| 159 | + args_config, remaining = config_parser.parse_known_args() |
| 160 | + if args_config.config: |
| 161 | + with open(args_config.config, 'r') as f: |
| 162 | + cfg = yaml.safe_load(f) |
| 163 | + parser.set_defaults(**cfg) |
| 164 | + |
| 165 | + # The main arg parser parses the rest of the args, the usual |
| 166 | + # defaults will have been overridden if config file specified. |
| 167 | + args = parser.parse_args(remaining) |
| 168 | + |
| 169 | + # Cache the args as a text string to save them in the output dir later |
| 170 | + args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) |
| 171 | + return args, args_text |
| 172 | + |
| 173 | + |
148 | 174 | def main(): |
149 | 175 | setup_default_logging() |
150 | | - args = parser.parse_args() |
| 176 | + args, args_text = _parse_args() |
| 177 | + |
151 | 178 | args.prefetcher = not args.no_prefetcher |
152 | 179 | args.distributed = False |
153 | 180 | if 'WORLD_SIZE' in os.environ: |
@@ -345,6 +372,8 @@ def main(): |
345 | 372 | output_dir = get_outdir(output_base, 'train', exp_name) |
346 | 373 | decreasing = True if eval_metric == 'loss' else False |
347 | 374 | saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) |
| 375 | + with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: |
| 376 | + f.write(args_text) |
348 | 377 |
|
349 | 378 | try: |
350 | 379 | for epoch in range(start_epoch, num_epochs): |
|
0 commit comments