@@ -55,9 +55,9 @@ def mutual_information(mc_preds):
5555 Compute the difference between the entropy of the mean of the
5656 predictive distribution and the mean of the entropy.
5757 """
58- MI = entropy (np .mean (mc_preds , axis = 0 )) - np .mean (entropy (mc_preds ),
59- axis = 0 )
60- return MI
58+ mutual_info = entropy (np .mean (mc_preds , axis = 0 )) - np .mean (entropy (mc_preds ),
59+ axis = 0 )
60+ return mutual_info
6161
6262
6363def get_rho (sigma , delta ):
@@ -86,39 +86,51 @@ def MOPED(model, det_model, det_checkpoint, delta):
8686 for (idx , layer ), (det_idx ,
8787 det_layer ) in zip (enumerate (model .modules ()),
8888 enumerate (det_model .modules ())):
89- if (str (layer ) == 'Conv1dVariational()'
90- or str (layer ) == 'Conv2dVariational()'
91- or str (layer ) == 'Conv3dVariational()'
92- or str (layer ) == 'ConvTranspose1dVariational()'
93- or str (layer ) == 'ConvTranspose2dVariational()'
94- or str (layer ) == 'ConvTranspose3dVariational()' ):
89+ if (str (layer ) == 'Conv1dReparameterization()'
90+ or str (layer ) == 'Conv2dReparameterization()'
91+ or str (layer ) == 'Conv3dReparameterization()'
92+ or str (layer ) == 'ConvTranspose1dReparameterization()'
93+ or str (layer ) == 'ConvTranspose2dReparameterization()'
94+ or str (layer ) == 'ConvTranspose3dReparameterization()'
95+ or str (layer ) == 'Conv1dFlipout()'
96+ or str (layer ) == 'Conv2dFlipout()'
97+ or str (layer ) == 'Conv3dFlipout()'
98+ or str (layer ) == 'ConvTranspose1dFlipout()'
99+ or str (layer ) == 'ConvTranspose2dFlipout()'
100+ or str (layer ) == 'ConvTranspose3dFlipout()' ):
95101 #set the priors
96- layer .prior_weight_mu .data = det_layer .weight
97- layer .prior_bias_mu .data = det_layer .bias
102+ layer .prior_weight_mu = det_layer .weight .data
103+ if layer .prior_bias_mu is not None :
104+ layer .prior_bias_mu = det_layer .bias .data
98105
99106 #initialize surrogate posteriors
100- layer .mu_kernel .data = det_layer .weight
107+ layer .mu_kernel .data = det_layer .weight . data
101108 layer .rho_kernel .data = get_rho (det_layer .weight .data , delta )
102- layer .mu_bias .data = det_layer .bias
103- layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
104- elif (str (layer ) == 'LinearVariational()' ):
109+ if layer .mu_bias is not None :
110+ layer .mu_bias .data = det_layer .bias .data
111+ layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
112+ elif (str (layer ) == 'LinearReparameterization()'
113+ or str (layer ) == 'LinearFlipout()' ):
105114 #set the priors
106- layer .prior_weight_mu .data = det_layer .weight
107- layer .prior_bias_mu .data = det_layer .bias
115+ layer .prior_weight_mu = det_layer .weight .data
116+ if layer .prior_bias_mu is not None :
117+ layer .prior_bias_mu = det_layer .bias .data
108118
109119 #initialize the surrogate posteriors
110- layer .mu_weight .data = det_layer .weight
120+ layer .mu_weight .data = det_layer .weight . data
111121 layer .rho_weight .data = get_rho (det_layer .weight .data , delta )
112- layer .mu_bias .data = det_layer .bias
113- layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
122+ if layer .mu_bias is not None :
123+ layer .mu_bias .data = det_layer .bias .data
124+ layer .rho_bias .data = get_rho (det_layer .bias .data , delta )
114125
115126 elif str (layer ).startswith ('Batch' ):
116127 #initialize parameters
117- layer .weight .data = det_layer .weight
118- layer .bias .data = det_layer .bias
119- layer .running_mean .data = det_layer .running_mean
120- layer .running_var .data = det_layer .running_var
121- layer .num_batches_tracked .data = det_layer .num_batches_tracked
128+ layer .weight .data = det_layer .weight .data
129+ if layer .bias is not None :
130+ layer .bias .data = det_layer .bias .data
131+ layer .running_mean .data = det_layer .running_mean .data
132+ layer .running_var .data = det_layer .running_var .data
133+ layer .num_batches_tracked .data = det_layer .num_batches_tracked .data
122134
123135 model .state_dict ()
124136 return model
0 commit comments