Skip to content

Commit aee6c28

Browse files
committed
Fix #4
1 parent 2907a85 commit aee6c28

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

bnn/layers/conv.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
3636
input_proc = self.activation_pre_process(input)
3737
input_proc = self._conv_forward(input_proc, self.weight_pre_process(self.weight), bias=self.bias)
3838

39+
if isinstance(input_proc, tuple) and len(input_proc) == 1:
40+
input_proc = input_proc[0]
41+
3942
return self.activation_post_process(
4043
input_proc,
4144
input
@@ -91,6 +94,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
9194
input_proc = self.activation_pre_process(input)
9295
input_proc = self._conv_forward(input_proc, self.weight_pre_process(self.weight), bias=self.bias)
9396

97+
if isinstance(input_proc, tuple) and len(input_proc) == 1:
98+
input_proc = input_proc[0]
99+
94100
return self.activation_post_process(
95101
input_proc,
96102
input

0 commit comments

Comments
 (0)