Skip to content

Commit 0356e77

Browse files
committed
Default to native PyTorch AMP instead of APEX amp. Too many APEX issues cropping up lately.
1 parent b4e216e commit 0356e77

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

timm/models/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
177177
if cfg is None:
178178
cfg = getattr(model, 'default_cfg')
179179
if cfg is None or 'url' not in cfg or not cfg['url']:
180-
_logger.warning("Pretrained model URL does not exist, using random initialization.")
180+
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
181181
return
182182

183183
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,11 @@ def main():
310310
# resolve AMP arguments based on PyTorch / Apex availability
311311
use_amp = None
312312
if args.amp:
313-
# for backwards compat, `--amp` arg tries apex before native amp
314-
if has_apex:
315-
args.apex_amp = True
316-
elif has_native_amp:
313+
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
314+
if has_native_amp:
317315
args.native_amp = True
316+
elif has_apex:
317+
args.apex_amp = True
318318
if args.apex_amp and has_apex:
319319
use_amp = 'apex'
320320
elif args.native_amp and has_native_amp:

validate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,20 @@ def validate(args):
116116
args.prefetcher = not args.no_prefetcher
117117
amp_autocast = suppress # do nothing
118118
if args.amp:
119-
if has_apex:
120-
args.apex_amp = True
121-
elif has_native_amp:
119+
if has_native_amp:
122120
args.native_amp = True
121+
elif has_apex:
122+
args.apex_amp = True
123123
else:
124-
_logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
124+
_logger.warning("Neither APEX or Native Torch AMP is available.")
125125
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
126126
if args.native_amp:
127127
amp_autocast = torch.cuda.amp.autocast
128+
_logger.info('Validating in mixed precision with native PyTorch AMP.')
129+
elif args.apex_amp:
130+
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
131+
else:
132+
_logger.info('Validating in float32. AMP not enabled.')
128133

129134
if args.legacy_jit:
130135
set_jit_legacy()

0 commit comments

Comments (0)