Skip to content

Commit 2907a85

Browse files
committed
Add test + missing case for conv1d
1 parent 8044a8f commit 2907a85

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

bnn/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def _compute_alpha(self, x: torch.Tensor) -> torch.Tensor:
117117
n = x[0].nelement()
118118
if x.dim() == 4:
119119
alpha = x.norm(1, 3, keepdim=True).sum([2, 1], keepdim=True).div_(n)
120+
elif x.dim() == 3:
121+
alpha = x.norm(1, 2, keepdim=True).sum([1], keepdim=True).div_(n)
120122
elif x.dim() == 2:
121123
alpha = x.norm(1, 1, keepdim=True).div_(n)
122124
else:

test/test_layers.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import copy
2+
import unittest
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
from bnn import BConfig, prepare_binary_model
8+
from bnn.ops import (
9+
BasicInputBinarizer,
10+
BasicScaleBinarizer,
11+
XNORWeightBinarizer
12+
)
13+
14+
class BinaryLayersTestCase(unittest.TestCase):
15+
def setUp(self) -> None:
16+
self.test_bconfig = BConfig(
17+
activation_pre_process=BasicInputBinarizer,
18+
activation_post_process=BasicScaleBinarizer,
19+
weight_pre_process=XNORWeightBinarizer
20+
)
21+
self.data = torch.tensor([-0.05263, -0.05068, -0.03849, 0.03104, 0.0772, 0.03038, -0.06640, 0.05894,
22+
0.13059, 0.03433, -0.25811, 0.13785]).view(1, 3, 2, 2)
23+
self.weights = torch.tensor([-0.0252, 0.0084, -0.0676, 0.0891, -0.0010, 0.0518, 0.0380, 0.2866,
24+
-0.0050])
25+
26+
def tearDown(self) -> None:
27+
pass
28+
29+
def test_linear_layer(self):
30+
layer = nn.Linear(3, 3, bias=False)
31+
layer.weight.data.copy_(self.weights.view(3, 3))
32+
x = self.data[:, :, 0, 0].view(1, 3)
33+
layer = prepare_binary_model(layer, bconfig=self.test_bconfig)
34+
35+
output = layer(x)
36+
expected = torch.tensor([[0.0337, -0.0473, -0.1099]])
37+
self.assertTrue(torch.allclose(expected, output, atol=1e-4))
38+
39+
def test_conv1d_layer(self):
40+
layer = nn.Conv1d(3, 3, 1, bias=False)
41+
layer.weight.data.copy_(self.weights.view(3, 3, 1))
42+
x = self.data[:,:,:,0].view(1, 3, 2)
43+
layer = prepare_binary_model(layer, bconfig=self.test_bconfig)
44+
45+
output = layer(x)
46+
expected = torch.tensor([[[ 0.0337, 0.0337],
47+
[-0.0473, -0.0473],
48+
[-0.1099, -0.1099]]])
49+
self.assertTrue(torch.allclose(expected, output, atol=1e-4))
50+
51+
def test_conv2d_layer(self):
52+
layer = nn.Conv2d(3, 3, 1, bias=False)
53+
layer.weight.data.copy_(self.weights.view(3, 3, 1, 1))
54+
x = self.data
55+
layer = prepare_binary_model(layer, bconfig=self.test_bconfig)
56+
57+
output = layer(x)
58+
expected = torch.tensor([[[[ 0.0337, 0.0337],
59+
[ 0.0337, -0.0337]],
60+
61+
[[-0.0473, -0.0473],
62+
[-0.0473, 0.0473]],
63+
64+
[[-0.1099, -0.1099],
65+
[-0.1099, 0.1099]]]])
66+
self.assertTrue(torch.allclose(expected, output, atol=1e-4))
67+
68+
69+
if __name__ == '__main__':
70+
unittest.main()

0 commit comments

Comments
 (0)