Skip to content

Commit 5719b49

Browse files
committed
Missed update dist-bn logic for EMA model
1 parent a435ea1 commit 5719b49

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,16 +393,16 @@ def main():
393393
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
394394
use_amp=use_amp, model_ema=model_ema)
395395

396-
if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
396+
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
397397
if args.local_rank == 0:
398398
logging.info("Distributing BatchNorm running means and vars")
399399
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
400400

401401
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
402402

403403
if model_ema is not None and not args.model_ema_force_cpu:
404-
if args.distributed and args.reduce_bn:
405-
distribute_bn(model_ema, args.world_size)
404+
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
405+
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
406406

407407
ema_eval_metrics = validate(
408408
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')

0 commit comments

Comments
 (0)