                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -436,10 +438,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
@@ -519,7 +529,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
 
@@ -589,7 +599,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -734,6 +744,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -758,6 +769,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
@@ -822,21 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.wandb_project,
-                name=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume="must" if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
@@ -886,6 +898,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -904,6 +917,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
 
                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -986,6 +1000,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1022,7 +1037,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -1149,6 +1164,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1164,8 +1180,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
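
The end-to-end effect of the new --model-dtype option is: cast the model weights once, feed the input tensors in the same dtype (via img_dtype on the loaders), and run without AMP autocast or a loss scaler. A minimal standalone sketch of that flow follows; the toy model, CPU device, and random data are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn

device = torch.device('cpu')      # 'cuda' in a real run
model_dtype = torch.bfloat16      # what --model-dtype bfloat16 resolves to via getattr(torch, ...)

# cast weights once, mirroring model.to(device=device, dtype=model_dtype)
model = nn.Linear(16, 4).to(device=device, dtype=model_dtype)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

for step in range(3):
    # the train/eval loaders perform this cast when img_dtype=model_dtype is passed
    input = torch.randn(8, 16).to(device=device, dtype=model_dtype)
    target = torch.randint(0, 4, (8,), device=device)

    output = model(input)         # plain low-precision forward, no amp_autocast
    loss = loss_fn(output, target)
    loss.backward()               # no GradScaler; bfloat16 keeps float32's exponent range
    optimizer.step()
    optimizer.zero_grad()

The new assertion in main() keeps the two paths distinct: --amp may only be combined with a float32 (or unset) model dtype, since AMP performs its own low-precision casts inside autocast regions.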