                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
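For reference, a minimal, self-contained sketch (a hypothetical parser setup, not the project's actual argument wiring) of how the new `--model-dtype` option sits next to the existing AMP flags; the defaults mirror the lines above.

```python
# Illustrative parser only; group names and the --amp flag's wiring are assumptions.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Device & precision')
group.add_argument('--amp', action='store_true', default=False)
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--model-dtype', default=None, type=str,
                   help='Model dtype override (non-AMP) (default: float32)')

args = parser.parse_args(['--model-dtype', 'bfloat16'])
print(args.model_dtype)  # 'bfloat16', resolved to torch.bfloat16 later in main()
```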
@@ -434,10 +436,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
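A standalone sketch of the resolution logic added above (the helper name `resolve_model_dtype` is illustrative, not part of the codebase): the dtype string is mapped onto a `torch` dtype with `getattr`, and any non-float32 override is rejected when AMP is requested, since autocast expects float32 master weights.

```python
# Illustrative helper mirroring the checks in the hunk above; not in the codebase.
import torch

def resolve_model_dtype(model_dtype_str, use_amp):
    model_dtype = None
    if model_dtype_str:
        assert model_dtype_str in ('float32', 'float16', 'bfloat16')
        model_dtype = getattr(torch, model_dtype_str)
        if model_dtype == torch.float16:
            print('warning: bfloat16 is preferred over float16 for training')
    if use_amp:
        assert model_dtype is None or model_dtype == torch.float32, \
            'float32 model dtype must be used with AMP'
    return model_dtype

print(resolve_model_dtype('bfloat16', use_amp=False))  # torch.bfloat16
```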
@@ -517,7 +527,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
 
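To make the one-line change above concrete, a small sketch with a throwaway module: passing `dtype` to `Module.to()` casts floating-point parameters and buffers, while the separate `channels_last` call only changes memory layout, not dtype. When `--model-dtype` is unset, `model_dtype` stays `None` and `Module.to()` leaves dtypes untouched.

```python
# Toy module for illustration only; the training script does this on the real model.
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, kernel_size=3)
model.to(device='cpu', dtype=torch.bfloat16)   # casts weights/bias to bfloat16
model.to(memory_format=torch.channels_last)    # layout change only, dtype unchanged
print(model.weight.dtype)  # torch.bfloat16
```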
@@ -587,7 +597,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -732,6 +742,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -756,6 +767,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
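Both loaders also receive the override through `img_dtype`. The reason, sketched conceptually below (this is not timm's actual prefetcher code), is that with `--prefetcher` the host-side `.to()` casts shown further down are skipped, so the loader itself has to yield image tensors already in the model dtype.

```python
# Conceptual stand-in for the prefetch path; real batches come from the dataset.
import torch

def prefetch_cast(batch, device, img_dtype=torch.float32):
    images, targets = batch
    return images.to(device=device, dtype=img_dtype), targets.to(device=device)

images, targets = prefetch_cast(
    (torch.rand(2, 3, 224, 224), torch.tensor([0, 1])),
    device=torch.device('cpu'),
    img_dtype=torch.bfloat16,
)
print(images.dtype, targets.dtype)  # torch.bfloat16 torch.int64
```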
@@ -823,9 +835,13 @@ def main():
     if utils.is_primary(args) and args.log_wandb:
         if has_wandb:
             assert not args.wandb_resume_id or args.resume
-            wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
-                       resume='must' if args.wandb_resume_id else None,
-                       id=args.wandb_resume_id if args.wandb_resume_id else None)
+            wandb.init(
+                project=args.experiment,
+                config=args,
+                tags=args.wandb_tags,
+                resume='must' if args.wandb_resume_id else None,
+                id=args.wandb_resume_id if args.wandb_resume_id else None,
+            )
         else:
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "
@@ -879,6 +895,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -897,6 +914,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
 
             if model_ema is not None and not args.model_ema_force_cpu:
@@ -979,6 +997,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1015,7 +1034,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -1142,6 +1161,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1157,8 +1177,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
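The non-prefetcher host path in both `train_one_epoch` and `validate` now casts only the images; class-index targets are just moved to the device, since the loss expects integer labels there. A tiny sketch with dummy tensors (when `--model-dtype` is unset, `model_dtype` is `None` and `.to(dtype=None)` is a no-op):

```python
# Dummy batch standing in for what the loader yields on the host path.
import torch

device = torch.device('cpu')
model_dtype = torch.bfloat16

input = torch.rand(2, 3, 224, 224)
target = torch.tensor([1, 7])

input = input.to(device=device, dtype=model_dtype)   # cast images to model dtype
target = target.to(device=device)                    # labels keep their dtype
print(input.dtype, target.dtype)  # torch.bfloat16 torch.int64
```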