@@ -112,6 +112,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -160,7 +162,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
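For context, the unchanged lines above are the reparameterization trick itself: the unconstrained parameter `rho_kernel` is passed through a softplus, `sigma = log(1 + exp(rho))`, so the standard deviation stays positive, and a weight sample is drawn as `w = mu + sigma * eps` with fresh standard-normal noise `eps`. A minimal standalone sketch of just that sampling step (the shapes and values here are illustrative, not the layer's actual ones):

```python
import torch

mu = torch.zeros(4)                  # variational posterior mean
rho = torch.full((4,), -3.0)         # unconstrained scale parameter
sigma = torch.log1p(torch.exp(rho))  # softplus keeps sigma > 0
eps = torch.randn(4)                 # fresh N(0, 1) noise each forward pass
w = mu + sigma * eps                 # sample, differentiable w.r.t. mu and rho
```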
@@ -182,7 +184,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
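The same change is applied to every layer in this file: the KL term is cached on the module as `self.kl`, and `return_kl=False` makes `forward` return only the output tensor, so the layer composes with containers such as `nn.Sequential` that expect a single return value. A minimal usage sketch, assuming the layer is importable as in bayesian-torch (the import path and constructor arguments here are assumptions for illustration):

```python
import torch
from bayesian_torch.layers import Conv1dReparameterization  # assumed import path

layer = Conv1dReparameterization(in_channels=3, out_channels=8, kernel_size=5)
x = torch.randn(2, 3, 32)

out, kl = layer(x)               # default behaviour: (output, kl) tuple, as before
out = layer(x, return_kl=False)  # new: output tensor only
kl = layer.kl                    # the KL term is cached on the module
```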
 
 
 class Conv2dReparameterization(BaseVariationalLayer_):
@@ -239,6 +245,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size))
@@ -292,7 +300,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -313,8 +321,12 @@ def forward(self, input):
             kl = kl_weight + kl_bias
         else:
             kl = kl_weight
+
+        self.kl = kl

-        return out, kl
+        if return_kl:
+            return out, kl
+        return out


 class Conv3dReparameterization(BaseVariationalLayer_):
@@ -371,6 +383,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -424,7 +438,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -446,7 +460,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class ConvTranspose1dReparameterization(BaseVariationalLayer_):
@@ -504,6 +522,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size))
         self.rho_kernel = Parameter(
@@ -552,7 +572,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -575,7 +595,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class ConvTranspose2dReparameterization(BaseVariationalLayer_):
@@ -633,6 +657,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size))
@@ -686,7 +712,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -709,7 +735,11 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out


 class ConvTranspose3dReparameterization(BaseVariationalLayer_):
@@ -768,6 +798,8 @@ def __init__(self,
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias

+        self.kl = 0
+
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size,
                          kernel_size, kernel_size))
@@ -821,7 +853,7 @@ def init_parameters(self):
         self.rho_bias.data.normal_(mean=self.posterior_rho_init[0],
                                    std=0.1)

-    def forward(self, input):
+    def forward(self, input, return_kl=True):
         sigma_weight = torch.log1p(torch.exp(self.rho_kernel))
         eps_kernel = self.eps_kernel.data.normal_()
         weight = self.mu_kernel + (sigma_weight * eps_kernel)
@@ -844,4 +876,8 @@ def forward(self, input):
         else:
             kl = kl_weight

-        return out, kl
+        self.kl = kl
+
+        if return_kl:
+            return out, kl
+        return out
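Caching the KL term on the module is what makes the `return_kl=False` path useful: after a forward pass, the KL contributions of all variational layers can be summed from outside the model, without threading `(out, kl)` tuples through intermediate containers. A sketch of such an accumulator, under the assumption that only the variational layers expose a `.kl` attribute (`collect_kl` is a hypothetical name; bayesian-torch provides a comparable `get_kl_loss` utility):

```python
import torch.nn as nn

def collect_kl(model: nn.Module):
    # Sum the `kl` values cached by each variational layer during forward().
    # Before the first forward pass every layer holds its initial value, 0.
    return sum(m.kl for m in model.modules() if hasattr(m, "kl"))
```

With layers invoked via `return_kl=False`, they drop into `nn.Sequential` unchanged, and an ELBO-style loss can be assembled afterwards, e.g. as the data likelihood term plus `collect_kl(model)` scaled by the batch size.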