|
| 1 | +import torch |
1 | 2 | from torchbench.image_classification import ImageNet |
2 | 3 | from timm import create_model |
3 | 4 | from timm.data import resolve_data_config, create_transform |
@@ -77,7 +78,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE, |
77 | 78 | _entry('mixnet_m', 'MixNet-M', '1907.09595'), |
78 | 79 | _entry('mixnet_s', 'MixNet-S', '1907.09595'), |
79 | 80 | _entry('mnasnet_100', 'MnasNet-B1', '1807.11626'), |
80 | | - _entry('mobilenetv3_100', 'MobileNet V3(1.0)', '1905.02244', |
| 81 | + _entry('mobilenetv3_100', 'MobileNet V3-Large 1.0', '1905.02244', |
81 | 82 | model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching ' |
82 | 83 | 'paper as closely as possible.'), |
83 | 84 | _entry('resnet18', 'ResNet-18', '1812.01187'), |
@@ -216,4 +217,6 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE, |
216 | 217 | data_root=os.environ.get('IMAGENET_DIR', './imagenet') |
217 | 218 | ) |
218 | 219 |
|
| 220 | + torch.cuda.empty_cache() |
| 221 | + |
219 | 222 |
|
0 commit comments