@@ -152,6 +152,14 @@ def init_parameters(self):
152152 self .prior_bias_mu .data .fill_ (self .prior_mean )
153153 self .prior_bias_sigma .data .fill_ (self .prior_variance )
154154
155+ def kl_loss (self ):
156+ sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
157+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu , self .prior_weight_sigma )
158+ if self .bias :
159+ sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
160+ kl += self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu , self .prior_bias_sigma )
161+ return kl
162+
155163 def forward (self , x , return_kl = True ):
156164
157165 if self .dnn_to_bnn_flag :
@@ -311,6 +319,14 @@ def init_parameters(self):
311319 self .prior_bias_mu .data .fill_ (self .prior_mean )
312320 self .prior_bias_sigma .data .fill_ (self .prior_variance )
313321
322+ def kl_loss (self ):
323+ sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
324+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu , self .prior_weight_sigma )
325+ if self .bias :
326+ sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
327+ kl += self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu , self .prior_bias_sigma )
328+ return kl
329+
314330 def forward (self , x , return_kl = True ):
315331
316332 if self .dnn_to_bnn_flag :
@@ -469,6 +485,14 @@ def init_parameters(self):
469485 self .prior_bias_mu .data .fill_ (self .prior_mean )
470486 self .prior_bias_sigma .data .fill_ (self .prior_variance )
471487
488+ def kl_loss (self ):
489+ sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
490+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu , self .prior_weight_sigma )
491+ if self .bias :
492+ sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
493+ kl += self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu , self .prior_bias_sigma )
494+ return kl
495+
472496 def forward (self , x , return_kl = True ):
473497
474498 if self .dnn_to_bnn_flag :
@@ -624,6 +648,14 @@ def init_parameters(self):
624648 self .prior_bias_mu .data .fill_ (self .prior_mean )
625649 self .prior_bias_sigma .data .fill_ (self .prior_variance )
626650
651+ def kl_loss (self ):
652+ sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
653+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu , self .prior_weight_sigma )
654+ if self .bias :
655+ sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
656+ kl += self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu , self .prior_bias_sigma )
657+ return kl
658+
627659 def forward (self , x , return_kl = True ):
628660
629661 if self .dnn_to_bnn_flag :
@@ -784,6 +816,14 @@ def init_parameters(self):
784816 self .prior_bias_mu .data .fill_ (self .prior_mean )
785817 self .prior_bias_sigma .data .fill_ (self .prior_variance )
786818
819+ def kl_loss (self ):
820+ sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
821+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu , self .prior_weight_sigma )
822+ if self .bias :
823+ sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
824+ kl += self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu , self .prior_bias_sigma )
825+ return kl
826+
787827 def forward (self , x , return_kl = True ):
788828
789829 if self .dnn_to_bnn_flag :
@@ -944,6 +984,14 @@ def init_parameters(self):
944984 self .prior_bias_mu .data .fill_ (self .prior_mean )
945985 self .prior_bias_sigma .data .fill_ (self .prior_variance )
946986
987+ def kl_loss (self ):
988+ sigma_weight = torch .log1p (torch .exp (self .rho_kernel ))
989+ kl = self .kl_div (self .mu_kernel , sigma_weight , self .prior_weight_mu , self .prior_weight_sigma )
990+ if self .bias :
991+ sigma_bias = torch .log1p (torch .exp (self .rho_bias ))
992+ kl += self .kl_div (self .mu_bias , sigma_bias , self .prior_bias_mu , self .prior_bias_sigma )
993+ return kl
994+
947995 def forward (self , x , return_kl = True ):
948996
949997 if self .dnn_to_bnn_flag :
0 commit comments