Skip to content

Commit 83f5091

Browse files
Port MONAI Generative utils (#7134)
Towards completing #6676 . ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham <markgraham539@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1c17f0e commit 83f5091

File tree

5 files changed

+157
-0
lines changed

5 files changed

+157
-0
lines changed

docs/source/utils.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,8 @@ State Cacher
7676
------------
7777
.. automodule:: monai.utils.state_cacher
7878
:members:
79+
80+
Component store
81+
---------------
82+
.. autoclass:: monai.utils.component_store.ComponentStore
83+
:members:

monai/utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
1919
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
2020
from .enums import (
21+
AdversarialIterationEvents,
22+
AdversarialKeys,
2123
AlgoKeys,
2224
Average,
2325
BlendMode,
@@ -47,6 +49,8 @@
4749
MetricReduction,
4850
NdimageMode,
4951
NumpyPadMode,
52+
OrderingTransformations,
53+
OrderingType,
5054
PatchKeys,
5155
PostFix,
5256
ProbMapKeys,
@@ -95,6 +99,8 @@
9599
str2bool,
96100
str2list,
97101
to_tuple_of_dictionaries,
102+
unsqueeze_left,
103+
unsqueeze_right,
98104
zip_with,
99105
)
100106
from .module import (

monai/utils/enums.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313

1414
import random
1515
from enum import Enum
16+
from typing import TYPE_CHECKING
1617

18+
from monai.config import IgniteInfo
1719
from monai.utils import deprecated
20+
from monai.utils.module import min_version, optional_import
1821

1922
__all__ = [
2023
"StrEnum",
@@ -88,6 +91,14 @@ def __repr__(self):
8891
return self.value
8992

9093

94+
if TYPE_CHECKING:
95+
from ignite.engine import EventEnum
96+
else:
97+
EventEnum, _ = optional_import(
98+
"ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
99+
)
100+
101+
91102
class NumpyPadMode(StrEnum):
92103
"""
93104
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
@@ -692,3 +703,57 @@ class AlgoKeys(StrEnum):
692703
ALGO = "algo_instance"
693704
IS_TRAINED = "is_trained"
694705
SCORE = "best_metric"
706+
707+
708+
class AdversarialKeys(StrEnum):
709+
"""
710+
Keys used by the AdversarialTrainer.
711+
`REALS` are real images from the batch.
712+
`FAKES` are fake images generated by the generator. Are the same as PRED.
713+
`REAL_LOGITS` are logits of the discriminator for the real images.
714+
`FAKE_LOGIT` are logits of the discriminator for the fake images.
715+
`RECONSTRUCTION_LOSS` is the loss value computed by the reconstruction loss function.
716+
`GENERATOR_LOSS` is the loss value computed by the generator loss function. It is the
717+
discriminator loss for the fake images. That is backpropagated through the generator only.
718+
`DISCRIMINATOR_LOSS` is the loss value computed by the discriminator loss function. It is the
719+
discriminator loss for the real images and the fake images. That is backpropagated through the
720+
discriminator only.
721+
"""
722+
723+
REALS = "reals"
724+
REAL_LOGITS = "real_logits"
725+
FAKES = "fakes"
726+
FAKE_LOGITS = "fake_logits"
727+
RECONSTRUCTION_LOSS = "reconstruction_loss"
728+
GENERATOR_LOSS = "generator_loss"
729+
DISCRIMINATOR_LOSS = "discriminator_loss"
730+
731+
732+
class AdversarialIterationEvents(EventEnum):
733+
"""
734+
Keys used to define events as used in the AdversarialTrainer.
735+
"""
736+
737+
RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
738+
GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
739+
GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
740+
GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
741+
GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
742+
GENERATOR_MODEL_COMPLETED = "generator_model_completed"
743+
DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
744+
DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
745+
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
746+
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
747+
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"
748+
749+
750+
class OrderingType(StrEnum):
751+
RASTER_SCAN = "raster_scan"
752+
S_CURVE = "s_curve"
753+
RANDOM = "random"
754+
755+
756+
class OrderingTransformations(StrEnum):
757+
ROTATE_90 = "rotate_90"
758+
TRANSPOSE = "transpose"
759+
REFLECT = "reflect"

monai/utils/misc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool:
888888
sqrt_num = [int(math.sqrt(_num)) for _num in num]
889889
ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]
890890
return ensure_tuple(ret) == num
891+
892+
893+
def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
894+
"""Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
895+
return arr[(...,) + (None,) * (ndim - arr.ndim)]
896+
897+
898+
def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
899+
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
900+
return arr[(None,) * (ndim - arr.ndim)]

tests/test_squeeze_unsqueeze.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.utils import unsqueeze_left, unsqueeze_right
21+
22+
RIGHT_CASES = [
23+
(np.random.rand(3, 4).astype(np.float32), 5, (3, 4, 1, 1, 1)),
24+
(torch.rand(3, 4).type(torch.float32), 5, (3, 4, 1, 1, 1)),
25+
(np.random.rand(3, 4).astype(np.float64), 5, (3, 4, 1, 1, 1)),
26+
(torch.rand(3, 4).type(torch.float64), 5, (3, 4, 1, 1, 1)),
27+
(np.random.rand(3, 4).astype(np.int32), 5, (3, 4, 1, 1, 1)),
28+
(torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)),
29+
]
30+
31+
32+
LEFT_CASES = [
33+
(np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)),
34+
(torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)),
35+
(np.random.rand(3, 4).astype(np.float64), 5, (1, 1, 1, 3, 4)),
36+
(torch.rand(3, 4).type(torch.float64), 5, (1, 1, 1, 3, 4)),
37+
(np.random.rand(3, 4).astype(np.int32), 5, (1, 1, 1, 3, 4)),
38+
(torch.rand(3, 4).type(torch.int32), 5, (1, 1, 1, 3, 4)),
39+
]
40+
ALL_CASES = [
41+
(np.random.rand(3, 4), 2, (3, 4)),
42+
(np.random.rand(3, 4), 0, (3, 4)),
43+
(np.random.rand(3, 4), -1, (3, 4)),
44+
(np.array(3), 4, (1, 1, 1, 1)),
45+
(np.array(3), 0, ()),
46+
(np.random.rand(3, 4).astype(np.int32), 2, (3, 4)),
47+
(np.random.rand(3, 4).astype(np.int32), 0, (3, 4)),
48+
(np.random.rand(3, 4).astype(np.int32), -1, (3, 4)),
49+
(np.array(3).astype(np.int32), 4, (1, 1, 1, 1)),
50+
(np.array(3).astype(np.int32), 0, ()),
51+
(torch.rand(3, 4), 2, (3, 4)),
52+
(torch.rand(3, 4), 0, (3, 4)),
53+
(torch.rand(3, 4), -1, (3, 4)),
54+
(torch.tensor(3), 4, (1, 1, 1, 1)),
55+
(torch.tensor(3), 0, ()),
56+
(torch.rand(3, 4).type(torch.int32), 2, (3, 4)),
57+
(torch.rand(3, 4).type(torch.int32), 0, (3, 4)),
58+
(torch.rand(3, 4).type(torch.int32), -1, (3, 4)),
59+
(torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)),
60+
(torch.tensor(3).type(torch.int32), 0, ()),
61+
]
62+
63+
64+
class TestUnsqueeze(unittest.TestCase):
65+
@parameterized.expand(RIGHT_CASES + ALL_CASES)
66+
def test_unsqueeze_right(self, arr, ndim, shape):
67+
self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)
68+
69+
@parameterized.expand(LEFT_CASES + ALL_CASES)
70+
def test_unsqueeze_left(self, arr, ndim, shape):
71+
self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)

0 commit comments

Comments
 (0)