|
46 | 46 | import torch.nn as nn |
47 | 47 | import torch.nn.functional as F |
48 | 48 | from torch.nn import Parameter |
49 | | -from ..base_variational_layer import BaseVariationalLayer_ |
| 49 | +from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size |
50 | 50 | import math |
51 | 51 |
|
52 | 52 | __all__ = [ |
@@ -255,26 +255,28 @@ def __init__(self, |
255 | 255 | self.posterior_rho_init = posterior_rho_init, |
256 | 256 | self.bias = bias |
257 | 257 |
|
| 258 | + kernel_size = get_kernel_size(kernel_size, 2) |
| 259 | + |
258 | 260 | self.mu_kernel = Parameter( |
259 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
260 | | - kernel_size)) |
| 261 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 262 | + kernel_size[1])) |
261 | 263 | self.rho_kernel = Parameter( |
262 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
263 | | - kernel_size)) |
| 264 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 265 | + kernel_size[1])) |
264 | 266 | self.register_buffer( |
265 | 267 | 'eps_kernel', |
266 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
267 | | - kernel_size), |
| 268 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 269 | + kernel_size[1]), |
268 | 270 | persistent=False) |
269 | 271 | self.register_buffer( |
270 | 272 | 'prior_weight_mu', |
271 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
272 | | - kernel_size), |
| 273 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 274 | + kernel_size[1]), |
273 | 275 | persistent=False) |
274 | 276 | self.register_buffer( |
275 | 277 | 'prior_weight_sigma', |
276 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
277 | | - kernel_size), |
| 278 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 279 | + kernel_size[1]), |
278 | 280 | persistent=False) |
279 | 281 |
|
280 | 282 | if self.bias: |
@@ -403,27 +405,27 @@ def __init__(self, |
403 | 405 | # variance of weight --> sigma = log (1 + exp(rho)) |
404 | 406 | self.posterior_rho_init = posterior_rho_init, |
405 | 407 | self.bias = bias |
406 | | - |
| 408 | + kernel_size = get_kernel_size(kernel_size, 3) |
407 | 409 | self.mu_kernel = Parameter( |
408 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
409 | | - kernel_size, kernel_size)) |
| 410 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 411 | + kernel_size[1], kernel_size[2])) |
410 | 412 | self.rho_kernel = Parameter( |
411 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
412 | | - kernel_size, kernel_size)) |
| 413 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 414 | + kernel_size[1], kernel_size[2])) |
413 | 415 | self.register_buffer( |
414 | 416 | 'eps_kernel', |
415 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
416 | | - kernel_size, kernel_size), |
| 417 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 418 | + kernel_size[1], kernel_size[2]), |
417 | 419 | persistent=False) |
418 | 420 | self.register_buffer( |
419 | 421 | 'prior_weight_mu', |
420 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
421 | | - kernel_size, kernel_size), |
| 422 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 423 | + kernel_size[1], kernel_size[2]), |
422 | 424 | persistent=False) |
423 | 425 | self.register_buffer( |
424 | 426 | 'prior_weight_sigma', |
425 | | - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
426 | | - kernel_size, kernel_size), |
| 427 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 428 | + kernel_size[1], kernel_size[2]), |
427 | 429 | persistent=False) |
428 | 430 |
|
429 | 431 | if self.bias: |
@@ -698,27 +700,27 @@ def __init__(self, |
698 | 700 | # variance of weight --> sigma = log (1 + exp(rho)) |
699 | 701 | self.posterior_rho_init = posterior_rho_init, |
700 | 702 | self.bias = bias |
701 | | - |
| 703 | + kernel_size = get_kernel_size(kernel_size, 2) |
702 | 704 | self.mu_kernel = Parameter( |
703 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
704 | | - kernel_size)) |
| 705 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 706 | + kernel_size[1])) |
705 | 707 | self.rho_kernel = Parameter( |
706 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
707 | | - kernel_size)) |
| 708 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 709 | + kernel_size[1])) |
708 | 710 | self.register_buffer( |
709 | 711 | 'eps_kernel', |
710 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
711 | | - kernel_size), |
| 712 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 713 | + kernel_size[1]), |
712 | 714 | persistent=False) |
713 | 715 | self.register_buffer( |
714 | 716 | 'prior_weight_mu', |
715 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
716 | | - kernel_size), |
| 717 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 718 | + kernel_size[1]), |
717 | 719 | persistent=False) |
718 | 720 | self.register_buffer( |
719 | 721 | 'prior_weight_sigma', |
720 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
721 | | - kernel_size), |
| 722 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 723 | + kernel_size[1]), |
722 | 724 | persistent=False) |
723 | 725 |
|
724 | 726 | if self.bias: |
@@ -850,27 +852,27 @@ def __init__(self, |
850 | 852 | # variance of weight --> sigma = log (1 + exp(rho)) |
851 | 853 | self.posterior_rho_init = posterior_rho_init, |
852 | 854 | self.bias = bias |
853 | | - |
| 855 | + kernel_size = get_kernel_size(kernel_size, 3) |
854 | 856 | self.mu_kernel = Parameter( |
855 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
856 | | - kernel_size, kernel_size)) |
| 857 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 858 | + kernel_size[1], kernel_size[2])) |
857 | 859 | self.rho_kernel = Parameter( |
858 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
859 | | - kernel_size, kernel_size)) |
| 860 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 861 | + kernel_size[1], kernel_size[2])) |
860 | 862 | self.register_buffer( |
861 | 863 | 'eps_kernel', |
862 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
863 | | - kernel_size, kernel_size), |
| 864 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 865 | + kernel_size[1], kernel_size[2]), |
864 | 866 | persistent=False) |
865 | 867 | self.register_buffer( |
866 | 868 | 'prior_weight_mu', |
867 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
868 | | - kernel_size, kernel_size), |
| 869 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 870 | + kernel_size[1], kernel_size[2]), |
869 | 871 | persistent=False) |
870 | 872 | self.register_buffer( |
871 | 873 | 'prior_weight_sigma', |
872 | | - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
873 | | - kernel_size, kernel_size), |
| 874 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 875 | + kernel_size[1], kernel_size[2]), |
874 | 876 | persistent=False) |
875 | 877 |
|
876 | 878 | if self.bias: |
|
0 commit comments