Skip to content

Commit 8044a8f

Browse files
committed
Fix typo
1 parent e80f6a4 commit 8044a8f

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

bnn/layers/conv.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
padding_mode: str = 'zeros',
2424
bconfig: BConfig = None
2525
) -> None:
26-
super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
26+
super(Conv1d, self).__init__(in_channels, out_channels, kernel_size,
2727
stride=stride, padding=padding, dilation=dilation,
2828
groups=groups, bias=bias, padding_mode=padding_mode)
2929
assert bconfig, 'bconfig is required for a binarized module'
@@ -34,8 +34,10 @@ def __init__(
3434

3535
def forward(self, input: torch.Tensor) -> torch.Tensor:
3636
input_proc = self.activation_pre_process(input)
37+
input_proc = self._conv_forward(input_proc, self.weight_pre_process(self.weight), bias=self.bias)
38+
3739
return self.activation_post_process(
38-
self._conv_forward(input_proc, self.weight_pre_process(self.weight), bias=self.bias),
40+
input_proc,
3941
input
4042
)
4143

@@ -87,8 +89,10 @@ def __init__(
8789

8890
def forward(self, input: torch.Tensor) -> torch.Tensor:
8991
input_proc = self.activation_pre_process(input)
92+
input_proc = self._conv_forward(input_proc, self.weight_pre_process(self.weight), bias=self.bias)
93+
9094
return self.activation_post_process(
91-
self._conv_forward(input_proc, self.weight_pre_process(self.weight), bias=self.bias),
95+
input_proc,
9296
input
9397
)
9498

0 commit comments

Comments
 (0)