|
4 | 4 | import pytest |
5 | 5 | import torch |
6 | 6 | import random |
| 7 | +import importlib |
7 | 8 |
|
8 | 9 | # mock detection module |
9 | 10 | sys.modules['torchvision._C'] = mock.Mock() |
@@ -35,18 +36,17 @@ def _select_names(names, k=2): |
35 | 36 | return names |
36 | 37 |
|
37 | 38 |
|
38 | | -def _test_forward_backward(model_fn, encoder_name): |
39 | | - |
40 | | - model = model_fn(encoder_name, encoder_weights=None) |
| 39 | +def _test_forward_backward(model_fn, encoder_name, **model_params): |
| 40 | + model = model_fn(encoder_name, encoder_weights=None, **model_params) |
41 | 41 |
|
42 | 42 | x = torch.ones((1, 3, 64, 64)) |
43 | 43 | y = model.forward(x) |
44 | 44 | l = y.mean() |
45 | 45 | l.backward() |
46 | 46 |
|
47 | 47 |
|
48 | | -def _test_pretrained_model(model_fn, encoder_name, encoder_weights): |
49 | | - model = model_fn(encoder_name, encoder_weights=encoder_weights) |
| 48 | +def _test_pretrained_model(model_fn, encoder_name, encoder_weights, **model_params): |
| 49 | + model = model_fn(encoder_name, encoder_weights=encoder_weights, **model_params) |
50 | 50 |
|
51 | 51 | x = torch.ones((1, 3, 64, 64)) |
52 | 52 | y = model.predict(x) |
@@ -82,5 +82,11 @@ def test_pspnet(encoder_name): |
82 | 82 | _test_pretrained_model(smp.PSPNet, encoder_name, get_pretrained_weights_name(encoder_name)) |
83 | 83 |
|
84 | 84 |
|
| 85 | +@pytest.mark.skipif(importlib.util.find_spec('inplace_abn') is None, reason='') |
| 86 | +def test_inplace_abn(): |
| 87 | + _test_forward_backward(smp.Unet, 'resnet18', decoder_use_batchnorm='inplace') |
| 88 | + _test_pretrained_model(smp.Unet, 'resnet18', get_pretrained_weights_name('resnet18'), decoder_use_batchnorm='inplace') |
| 89 | + |
| 90 | + |
85 | 91 | if __name__ == '__main__': |
86 | 92 | pytest.main([__file__]) |
0 commit comments