Skip to content

Commit ff8688c

Browse files
authored
Merge pull request #62 from rwightman/reduce-bn
Distribute BatchNorm stats
2 parents 5d7af97 + 5719b49 commit ff8688c

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

timm/data/random_erasing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='
77
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
88
# paths, flip the order so normal is run on CPU if this becomes a problem
99
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
10-
# will revert back to doing normal_() on GPU when it's in next release
1110
if per_pixel:
12-
return torch.empty(
13-
patch_size, dtype=dtype).normal_().to(device=device)
11+
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
1412
elif rand_color:
15-
return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
13+
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
1614
else:
1715
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
1816

timm/utils.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@
2121
from torch import distributed as dist
2222

2323

24-
def get_state_dict(model):
24+
def unwrap_model(model):
2525
if isinstance(model, ModelEma):
26-
return get_state_dict(model.ema)
26+
return unwrap_model(model.ema)
2727
else:
28-
return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
28+
return model.module if hasattr(model, 'module') else model
29+
30+
31+
def get_state_dict(model):
32+
return unwrap_model(model).state_dict()
2933

3034

3135
class CheckpointSaver:
@@ -206,6 +210,19 @@ def reduce_tensor(tensor, n):
206210
return rt
207211

208212

213+
def distribute_bn(model, world_size, reduce=False):
214+
# ensure every node has the same running bn stats
215+
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
216+
if ('running_mean' in bn_name) or ('running_var' in bn_name):
217+
if reduce:
218+
# average bn stats across whole group
219+
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
220+
bn_buf /= float(world_size)
221+
else:
222+
# broadcast bn stats from rank 0 to whole group
223+
torch.distributed.broadcast(bn_buf, 0)
224+
225+
209226
class ModelEma:
210227
""" Model Exponential Moving Average
211228
Keep a moving average of everything in the model state_dict (parameters and buffers).

train.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
5656
parser.add_argument('--img-size', type=int, default=None, metavar='N',
5757
help='Image patch size (default: None => model default)')
58+
parser.add_argument('--crop-pct', default=None, type=float,
59+
metavar='N', help='Input image center crop percent (for validation only)')
5860
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
5961
help='Override mean pixel value of dataset')
6062
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -121,6 +123,10 @@
121123
help='BatchNorm momentum override (if not None)')
122124
parser.add_argument('--bn-eps', type=float, default=None,
123125
help='BatchNorm epsilon override (if not None)')
126+
parser.add_argument('--sync-bn', action='store_true',
127+
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
128+
parser.add_argument('--dist-bn', type=str, default='',
129+
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
124130
# Model Exponential Moving Average
125131
parser.add_argument('--model-ema', action='store_true', default=False,
126132
help='Enable tracking moving average of model weights')
@@ -143,8 +149,6 @@
143149
help='save images of input batches every log interval for debugging')
144150
parser.add_argument('--amp', action='store_true', default=False,
145151
help='use NVIDIA amp for mixed precision training')
146-
parser.add_argument('--sync-bn', action='store_true',
147-
help='enabling apex sync BN.')
148152
parser.add_argument('--no-prefetcher', action='store_true', default=False,
149153
help='disable fast prefetcher')
150154
parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -256,7 +260,7 @@ def main():
256260
if args.local_rank == 0:
257261
logging.info('Restoring NVIDIA AMP state from checkpoint')
258262
amp.load_state_dict(resume_state['amp'])
259-
resume_state = None # clear it
263+
del resume_state
260264

261265
model_ema = None
262266
if args.model_ema:
@@ -347,6 +351,7 @@ def main():
347351
std=data_config['std'],
348352
num_workers=args.workers,
349353
distributed=args.distributed,
354+
crop_pct=data_config['crop_pct'],
350355
)
351356

352357
if args.mixup > 0.:
@@ -388,9 +393,17 @@ def main():
388393
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
389394
use_amp=use_amp, model_ema=model_ema)
390395

396+
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
397+
if args.local_rank == 0:
398+
logging.info("Distributing BatchNorm running means and vars")
399+
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
400+
391401
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
392402

393403
if model_ema is not None and not args.model_ema_force_cpu:
404+
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
405+
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
406+
394407
ema_eval_metrics = validate(
395408
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
396409
eval_metrics = ema_eval_metrics

0 commit comments

Comments (0)