@@ -23,7 +23,7 @@ def __init__(
2323 padding_mode : str = 'zeros' ,
2424 bconfig : BConfig = None
2525 ) -> None :
26- super (Conv2d , self ).__init__ (in_channels , out_channels , kernel_size ,
26+ super (Conv1d , self ).__init__ (in_channels , out_channels , kernel_size ,
2727 stride = stride , padding = padding , dilation = dilation ,
2828 groups = groups , bias = bias , padding_mode = padding_mode )
2929 assert bconfig , 'bconfig is required for a binarized module'
@@ -34,8 +34,10 @@ def __init__(
3434
3535 def forward (self , input : torch .Tensor ) -> torch .Tensor :
3636 input_proc = self .activation_pre_process (input )
37+ input_proc = self ._conv_forward (input_proc , self .weight_pre_process (self .weight ), bias = self .bias )
38+
3739 return self .activation_post_process (
38- self . _conv_forward ( input_proc , self . weight_pre_process ( self . weight ), bias = self . bias ) ,
40+ input_proc ,
3941 input
4042 )
4143
@@ -87,8 +89,10 @@ def __init__(
8789
8890 def forward (self , input : torch .Tensor ) -> torch .Tensor :
8991 input_proc = self .activation_pre_process (input )
92+ input_proc = self ._conv_forward (input_proc , self .weight_pre_process (self .weight ), bias = self .bias )
93+
9094 return self .activation_post_process (
91- self . _conv_forward ( input_proc , self . weight_pre_process ( self . weight ), bias = self . bias ) ,
95+ input_proc ,
9296 input
9397 )
9498
0 commit comments