
Commit 0d38e4b

Disallow incomplete defs in optimizers module (#5928)
Part of #5884.

### Description

Fully type annotate any functions with at least one type annotation in module `optimizers`.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.

Signed-off-by: Felix Schnabel <f.schnabel@tum.de>
1 parent e90bb84 commit 0d38e4b

File tree

4 files changed: +20 additions, −7 deletions


monai/optimizers/lr_finder.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import pickle
+import types
 import warnings
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable
@@ -187,7 +188,7 @@ def __init__(
         memory_cache: bool = True,
         cache_dir: str | None = None,
         amp: bool = False,
-        pickle_module=pickle,
+        pickle_module: types.ModuleType = pickle,
         pickle_protocol: int = DEFAULT_PROTOCOL,
         verbose: bool = True,
     ) -> None:
@@ -389,7 +390,9 @@ def _check_for_scheduler(self):
             if "initial_lr" in param_group:
                 raise RuntimeError("Optimizer already has a scheduler attached to it")
 
-    def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float:
+    def _train_batch(
+        self, train_iter: TrainDataLoaderIter, accumulation_steps: int, non_blocking_transfer: bool = True
+    ) -> float:
         self.model.train()
         total_loss = 0
 
@@ -478,7 +481,14 @@ def get_steepest_gradient(self, skip_start: int = 0, skip_end: int = 0) -> tuple
             print("Failed to compute the gradients, there might not be enough points.")
             return None, None
 
-    def plot(self, skip_start: int = 0, skip_end: int = 0, log_lr: bool = True, ax=None, steepest_lr: bool = True):
+    def plot(
+        self,
+        skip_start: int = 0,
+        skip_end: int = 0,
+        log_lr: bool = True,
+        ax: Any | None = None,
+        steepest_lr: bool = True,
+    ) -> Any | None:
         """Plots the learning rate range test.
 
         Args:
```
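For context on the `pickle_module: types.ModuleType = pickle` change: `types.ModuleType` is the standard annotation for a parameter that accepts a module object, so drop-in pickle replacements still type-check. A minimal sketch of the same pattern; the `save_checkpoint` helper below is hypothetical and not part of MONAI:

```python
from __future__ import annotations

import pickle
import types
from typing import Any


def save_checkpoint(obj: Any, path: str, pickle_module: types.ModuleType = pickle) -> None:
    # Hypothetical helper: any module exposing a pickle-compatible `dump` satisfies ModuleType.
    with open(path, "wb") as f:
        pickle_module.dump(obj, f)


save_checkpoint({"lr": 1e-3}, "state.pkl")  # e.g. pass `dill` instead of `pickle` for wider object support
```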

monai/optimizers/novograd.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -11,11 +11,14 @@
 
 from __future__ import annotations
 
-from typing import Callable, Iterable
+from collections.abc import Callable, Iterable
+from typing import TypeVar
 
 import torch
 from torch.optim import Optimizer
 
+T = TypeVar("T")
+
 
 class Novograd(Optimizer):
     """
@@ -67,7 +70,7 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault("amsgrad", False)
 
-    def step(self, closure: Callable | None = None):
+    def step(self, closure: Callable[[], T] | None = None) -> T | None:
         """Performs a single optimization step.
 
         Arguments:
```
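The new `step` signature uses a `TypeVar` so the optional closure's return type flows through to the caller: with `Callable[[], T] | None`, the method is typed as returning `T | None`. A toy sketch of the same pattern, not the Novograd implementation:

```python
from __future__ import annotations

from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def step(closure: Callable[[], T] | None = None) -> T | None:
    # Re-evaluate and return the loss if a closure is given, mirroring the torch.optim.Optimizer.step contract.
    if closure is None:
        return None
    return closure()


loss = step(lambda: 0.25)  # mypy infers `float | None` here
```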

monai/optimizers/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -26,7 +26,7 @@ def generate_param_groups(
     match_types: Sequence[str],
     lr_values: Sequence[float],
     include_others: bool = True,
-):
+) -> list[dict]:
     """
     Utility function to generate parameter groups with different LR values for optimizer.
     The output parameter groups have the same order as `layer_match` functions.
```
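The added `-> list[dict]` return type matches the structure PyTorch optimizers consume: a list of dicts, each with a `params` entry plus per-group hyperparameters. A small illustration of how such a list is used, assuming a throwaway `torch.nn.Linear` model rather than MONAI code:

```python
import torch

net = torch.nn.Linear(4, 2)

# Same shape as the list[dict] that generate_param_groups returns.
param_groups: list[dict] = [
    {"params": [net.weight], "lr": 1e-3},
    {"params": [net.bias], "lr": 1e-2},
]

# The constructor-level lr is only the default for groups that do not set their own.
optimizer = torch.optim.SGD(param_groups, lr=1e-4)
```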

setup.cfg

Lines changed: 1 addition & 1 deletion
```diff
@@ -222,7 +222,7 @@ check_untyped_defs = True
 # Warns about usage of untyped decorators.
 disallow_untyped_decorators = True
 
-[mypy-monai.visualize.*,monai.utils.*]
+[mypy-monai.visualize.*,monai.utils.*,monai.optimizers.*]
 disallow_incomplete_defs = True
 
 [coverage:run]
```
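Adding `monai.optimizers.*` to this section means mypy now rejects partially annotated signatures in the module, while fully unannotated functions remain governed by other flags. A sketch of what `disallow_incomplete_defs` does and does not flag, using hypothetical functions for illustration:

```python
# With disallow_incomplete_defs = True:

def scaled(value: float, factor):  # flagged: partially annotated (`factor` and the return type are missing)
    return value * factor


def scaled_ok(value: float, factor: float) -> float:  # fully annotated: accepted
    return value * factor


def untyped(value, factor):  # no annotations at all: handled by disallow_untyped_defs, not this flag
    return value * factor
```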
