Skip to content

Commit 21ad9c8

Browse files
KumoLiu and wyli authored
Fix 1D data error in VarAutoEncoder (#5236)
Signed-off-by: KumoLiu <yunl@nvidia.com> Fixes #5225 . ### Description 1. fix 1d data error in `VarAutoEncoder` 2. add `use_sigmoid` flag ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: KumoLiu <yunl@nvidia.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
1 parent 252b797 commit 21ad9c8

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

monai/networks/layers/convutils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def calculate_out_shape(
7474
out_shape_np = ((in_shape_np - kernel_size_np + padding_np + padding_np) // stride_np) + 1
7575
out_shape = tuple(int(s) for s in out_shape_np)
7676

77-
return out_shape if len(out_shape) > 1 else out_shape[0]
77+
return out_shape
7878

7979

8080
def gaussian_1d(

monai/networks/nets/varautoencoder.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ class VarAutoEncoder(AutoEncoder):
4848
bias: whether to have a bias term in convolution blocks. Defaults to True.
4949
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
5050
if a conv layer is directly followed by a batch norm layer, bias should be False.
51+
use_sigmoid: whether to use the sigmoid function on final output. Defaults to True.
5152
5253
Examples::
5354
@@ -86,9 +87,11 @@ def __init__(
8687
norm: Union[Tuple, str] = Norm.INSTANCE,
8788
dropout: Optional[Union[Tuple, str, float]] = None,
8889
bias: bool = True,
90+
use_sigmoid: bool = True,
8991
) -> None:
9092

9193
self.in_channels, *self.in_shape = in_shape
94+
self.use_sigmoid = use_sigmoid
9295

9396
self.latent_size = latent_size
9497
self.final_size = np.asarray(self.in_shape, dtype=int)
@@ -148,4 +151,4 @@ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor
148151
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
149152
mu, logvar = self.encode_forward(x)
150153
z = self.reparameterize(mu, logvar)
151-
return self.decode_forward(z), mu, logvar, z
154+
return self.decode_forward(z, self.use_sigmoid), mu, logvar, z

tests/test_varautoencoder.py

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,34 @@
7575
(1, 3, 128, 128, 128),
7676
]
7777

78-
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]
78+
TEST_CASE_4 = [ # 4-channel 1D, batch 4
79+
{
80+
"spatial_dims": 1,
81+
"in_shape": (4, 128),
82+
"out_channels": 3,
83+
"latent_size": 2,
84+
"channels": (4, 8, 16),
85+
"strides": (2, 2, 2),
86+
},
87+
(1, 4, 128),
88+
(1, 3, 128),
89+
]
90+
91+
TEST_CASE_5 = [ # 4-channel 1D, batch 4, use_sigmoid = False
92+
{
93+
"spatial_dims": 1,
94+
"in_shape": (4, 128),
95+
"out_channels": 3,
96+
"latent_size": 2,
97+
"channels": (4, 8, 16),
98+
"strides": (2, 2, 2),
99+
"use_sigmoid": False,
100+
},
101+
(1, 4, 128),
102+
(1, 3, 128),
103+
]
104+
105+
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]
79106

80107

81108
class TestVarAutoEncoder(unittest.TestCase):

0 commit comments

Comments (0)