|
55 | 55 | help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') |
56 | 56 | parser.add_argument('--img-size', type=int, default=None, metavar='N', |
57 | 57 | help='Image patch size (default: None => model default)') |
| 58 | +parser.add_argument('--crop-pct', default=None, type=float, |
| 59 | + metavar='N', help='Input image center crop percent (for validation only)') |
58 | 60 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', |
59 | 61 | help='Override mean pixel value of dataset') |
60 | 62 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', |
|
121 | 123 | help='BatchNorm momentum override (if not None)') |
122 | 124 | parser.add_argument('--bn-eps', type=float, default=None, |
123 | 125 | help='BatchNorm epsilon override (if not None)') |
| 126 | +parser.add_argument('--sync-bn', action='store_true', |
| 127 | + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') |
| 128 | +parser.add_argument('--dist-bn', type=str, default='', |
| 129 | + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') |
124 | 130 | # Model Exponential Moving Average |
125 | 131 | parser.add_argument('--model-ema', action='store_true', default=False, |
126 | 132 | help='Enable tracking moving average of model weights') |
|
143 | 149 | help='save images of input batches every log interval for debugging') |
144 | 150 | parser.add_argument('--amp', action='store_true', default=False, |
145 | 151 | help='use NVIDIA amp for mixed precision training') |
146 | | -parser.add_argument('--sync-bn', action='store_true', |
147 | | - help='enabling apex sync BN.') |
148 | 152 | parser.add_argument('--no-prefetcher', action='store_true', default=False, |
149 | 153 | help='disable fast prefetcher') |
150 | 154 | parser.add_argument('--output', default='', type=str, metavar='PATH', |
@@ -256,7 +260,7 @@ def main(): |
256 | 260 | if args.local_rank == 0: |
257 | 261 | logging.info('Restoring NVIDIA AMP state from checkpoint') |
258 | 262 | amp.load_state_dict(resume_state['amp']) |
259 | | - resume_state = None # clear it |
| 263 | + del resume_state |
260 | 264 |
|
261 | 265 | model_ema = None |
262 | 266 | if args.model_ema: |
@@ -347,6 +351,7 @@ def main(): |
347 | 351 | std=data_config['std'], |
348 | 352 | num_workers=args.workers, |
349 | 353 | distributed=args.distributed, |
| 354 | + crop_pct=data_config['crop_pct'], |
350 | 355 | ) |
351 | 356 |
|
352 | 357 | if args.mixup > 0.: |
@@ -388,9 +393,17 @@ def main(): |
388 | 393 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, |
389 | 394 | use_amp=use_amp, model_ema=model_ema) |
390 | 395 |
|
| 396 | + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): |
| 397 | + if args.local_rank == 0: |
| 398 | + logging.info("Distributing BatchNorm running means and vars") |
| 399 | + distribute_bn(model, args.world_size, args.dist_bn == 'reduce') |
| 400 | + |
391 | 401 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args) |
392 | 402 |
|
393 | 403 | if model_ema is not None and not args.model_ema_force_cpu: |
| 404 | + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): |
| 405 | + distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') |
| 406 | + |
394 | 407 | ema_eval_metrics = validate( |
395 | 408 | model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') |
396 | 409 | eval_metrics = ema_eval_metrics |
|
0 commit comments