Skip to content

Commit daad127

Browse files
cyrjanometa-codesync[bot]
authored andcommitted
Unify the Progress classes (#1668)
Summary: Pull Request resolved: #1668 This diff aims to unify the Progress classes used in Captum. The changes include updating the dependencies and adding a new protocol for the base progress bar interface. The diff also includes changes to the feature_ablation.py file to use the new progress class. Reviewed By: craymichael Differential Revision: D87560189 fbshipit-source-id: 71d43ee39a89e148e5a45d98c30a1f6e0c3af901
1 parent 4236431 commit daad127

File tree

2 files changed

+66
-55
lines changed

2 files changed

+66
-55
lines changed

captum/_utils/progress.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,64 @@
33
# pyre-strict
44

55
import typing
6-
from types import TracebackType
76
from typing import (
87
Any,
98
Callable,
10-
cast,
119
Iterable,
1210
Iterator,
13-
Literal,
1411
Optional,
12+
Protocol,
13+
runtime_checkable,
1514
TextIO,
16-
Type,
1715
TypeVar,
1816
Union,
1917
)
2018

2119
from tqdm.auto import tqdm
20+
from typing_extensions import Self
2221

2322
T = TypeVar("T")
24-
IterableType = TypeVar("IterableType")
23+
IterableType = TypeVar("IterableType", covariant=True)
24+
25+
26+
@runtime_checkable
27+
class BaseProgress(Protocol):
28+
"""
29+
Protocol defining the base progress bar interfaced with
30+
context manager support.
31+
Note: This protocol is based on the tqdm type stubs.
32+
"""
33+
34+
def __enter__(self) -> Self: ...
35+
36+
def __exit__(
37+
self,
38+
exc_type: object,
39+
exc_value: object,
40+
exc_traceback: object,
41+
) -> None: ...
42+
43+
def close(self) -> None: ...
44+
45+
46+
@runtime_checkable
47+
class IterableProgress(BaseProgress, Iterable[IterableType], Protocol[IterableType]):
48+
"""Protocol for progress bars that support iteration.
49+
50+
Note: This protocol is based on the tqdm type stubs.
51+
"""
52+
53+
...
54+
55+
56+
@runtime_checkable
57+
class Progress(BaseProgress, Protocol):
58+
"""Protocol for progress bars that support manual updates.
59+
Note: This protocol is based on the tqdm type stubs.
60+
"""
61+
62+
# This is a weird definition of Progress, but it's what tqdm does.
63+
def update(self, n: float | None = 1) -> bool | None: ...
2564

2665

2766
class DisableErrorIOWrapper(object):
@@ -56,7 +95,7 @@ def flush(self, *args: object, **kwargs: object) -> None:
5695
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
5796

5897

59-
class NullProgress(Iterable[IterableType]):
98+
class NullProgress(IterableProgress[IterableType], Progress):
6099
"""Passthrough class that implements the progress API.
61100
62101
This class implements the tqdm and SimpleProgressBar api but
@@ -74,25 +113,27 @@ def __init__(
74113
del args, kwargs
75114
self.iterable = iterable
76115

77-
def __enter__(self) -> "NullProgress[IterableType]":
116+
def __iter__(self) -> Iterator[IterableType]:
117+
iterable = self.iterable
118+
if not iterable:
119+
yield from ()
120+
return
121+
for it in iterable:
122+
yield it
123+
124+
def __enter__(self) -> Self:
78125
return self
79126

80127
def __exit__(
81128
self,
82-
exc_type: Union[Type[BaseException], None],
83-
exc_value: Union[BaseException, None],
84-
exc_traceback: Union[TracebackType, None],
85-
) -> Literal[False]:
86-
return False
87-
88-
def __iter__(self) -> Iterator[IterableType]:
89-
if not self.iterable:
90-
return
91-
for it in cast(Iterable[IterableType], self.iterable):
92-
yield it
129+
exc_type: object,
130+
exc_value: object,
131+
exc_traceback: object,
132+
) -> None:
133+
self.close()
93134

94-
def update(self, amount: int = 1) -> None:
95-
pass
135+
def update(self, n: float | None = 1) -> bool | None:
136+
return None
96137

97138
def close(self) -> None:
98139
pass
@@ -106,7 +147,7 @@ def progress(
106147
file: Optional[TextIO] = None,
107148
mininterval: float = 0.5,
108149
**kwargs: object,
109-
) -> tqdm: ...
150+
) -> Progress: ...
110151

111152

112153
@typing.overload
@@ -117,7 +158,7 @@ def progress(
117158
file: Optional[TextIO] = None,
118159
mininterval: float = 0.5,
119160
**kwargs: object,
120-
) -> tqdm: ...
161+
) -> IterableProgress[IterableType]: ...
121162

122163

123164
def progress(
@@ -127,7 +168,7 @@ def progress(
127168
file: Optional[TextIO] = None,
128169
mininterval: float = 0.5,
129170
**kwargs: object,
130-
) -> tqdm:
171+
) -> Union[Progress, IterableProgress[IterableType]]:
131172
return tqdm(
132173
iterable,
133174
desc=desc,

captum/attr/_core/feature_ablation.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,7 @@
44

55
import logging
66
import math
7-
from typing import (
8-
Any,
9-
Callable,
10-
cast,
11-
Dict,
12-
Iterable,
13-
List,
14-
Optional,
15-
Protocol,
16-
Tuple,
17-
TypeVar,
18-
Union,
19-
)
7+
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
208

219
import torch
2210
from captum._utils.common import (
@@ -31,7 +19,7 @@
3119
_run_forward,
3220
)
3321
from captum._utils.exceptions import FeatureAblationFutureError
34-
from captum._utils.progress import progress
22+
from captum._utils.progress import NullProgress, progress, Progress
3523
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
3624
from captum.attr._utils.attribution import PerturbationAttribution
3725
from captum.attr._utils.common import (
@@ -43,27 +31,9 @@
4331
from torch.futures import collect_all, Future
4432

4533

46-
IterableType = TypeVar("IterableType")
47-
4834
logger: logging.Logger = logging.getLogger(__name__)
4935

5036

51-
class Progress(Protocol):
52-
def update(self, n: int = 1) -> Optional[bool]:
53-
"""TQDM Update method signature."""
54-
55-
def close(self) -> None:
56-
"""TQDM Close method signature."""
57-
58-
59-
class NullProgress:
60-
def update(self, n: int = 1) -> Optional[bool]:
61-
return None
62-
63-
def close(self) -> None:
64-
return None
65-
66-
6737
def _parse_forward_out(forward_output: object) -> Tensor:
6838
"""
6939
A temp wrapper for global _run_forward util to force forward output

0 commit comments

Comments
 (0)