33# pyre-strict
44
55import typing
6- from types import TracebackType
76from typing import (
87 Any ,
98 Callable ,
10- cast ,
119 Iterable ,
1210 Iterator ,
13- Literal ,
1411 Optional ,
12+ Protocol ,
13+ runtime_checkable ,
1514 TextIO ,
16- Type ,
1715 TypeVar ,
1816 Union ,
1917)
2018
2119from tqdm .auto import tqdm
20+ from typing_extensions import Self
2221
2322T = TypeVar ("T" )
24- IterableType = TypeVar ("IterableType" )
23+ IterableType = TypeVar ("IterableType" , covariant = True )
24+
25+
26+ @runtime_checkable
27+ class BaseProgress (Protocol ):
28+ """
29+ Protocol defining the base progress bar interfaced with
30+ context manager support.
31+ Note: This protocol is based on the tqdm type stubs.
32+ """
33+
34+ def __enter__ (self ) -> Self : ...
35+
36+ def __exit__ (
37+ self ,
38+ exc_type : object ,
39+ exc_value : object ,
40+ exc_traceback : object ,
41+ ) -> None : ...
42+
43+ def close (self ) -> None : ...
44+
45+
46+ @runtime_checkable
47+ class IterableProgress (BaseProgress , Iterable [IterableType ], Protocol [IterableType ]):
48+ """Protocol for progress bars that support iteration.
49+
50+ Note: This protocol is based on the tqdm type stubs.
51+ """
52+
53+ ...
54+
55+
56+ @runtime_checkable
57+ class Progress (BaseProgress , Protocol ):
58+ """Protocol for progress bars that support manual updates.
59+ Note: This protocol is based on the tqdm type stubs.
60+ """
61+
62+ # This is a weird definition of Progress, but it's what tqdm does.
63+ def update (self , n : float | None = 1 ) -> bool | None : ...
2564
2665
2766class DisableErrorIOWrapper (object ):
@@ -56,7 +95,7 @@ def flush(self, *args: object, **kwargs: object) -> None:
5695 return self ._wrapped_run (self ._wrapped .flush , * args , ** kwargs )
5796
5897
59- class NullProgress (Iterable [IterableType ]):
98+ class NullProgress (IterableProgress [IterableType ], Progress ):
6099 """Passthrough class that implements the progress API.
61100
62101 This class implements the tqdm and SimpleProgressBar api but
@@ -74,25 +113,27 @@ def __init__(
74113 del args , kwargs
75114 self .iterable = iterable
76115
77- def __enter__ (self ) -> "NullProgress[IterableType]" :
116+ def __iter__ (self ) -> Iterator [IterableType ]:
117+ iterable = self .iterable
118+ if not iterable :
119+ yield from ()
120+ return
121+ for it in iterable :
122+ yield it
123+
124+ def __enter__ (self ) -> Self :
78125 return self
79126
80127 def __exit__ (
81128 self ,
82- exc_type : Union [Type [BaseException ], None ],
83- exc_value : Union [BaseException , None ],
84- exc_traceback : Union [TracebackType , None ],
85- ) -> Literal [False ]:
86- return False
87-
88- def __iter__ (self ) -> Iterator [IterableType ]:
89- if not self .iterable :
90- return
91- for it in cast (Iterable [IterableType ], self .iterable ):
92- yield it
129+ exc_type : object ,
130+ exc_value : object ,
131+ exc_traceback : object ,
132+ ) -> None :
133+ self .close ()
93134
94- def update (self , amount : int = 1 ) -> None :
95- pass
135+ def update (self , n : float | None = 1 ) -> bool | None :
136+ return None
96137
97138 def close (self ) -> None :
98139 pass
@@ -106,7 +147,7 @@ def progress(
106147 file : Optional [TextIO ] = None ,
107148 mininterval : float = 0.5 ,
108149 ** kwargs : object ,
109- ) -> tqdm : ...
150+ ) -> Progress : ...
110151
111152
112153@typing .overload
@@ -117,7 +158,7 @@ def progress(
117158 file : Optional [TextIO ] = None ,
118159 mininterval : float = 0.5 ,
119160 ** kwargs : object ,
120- ) -> tqdm : ...
161+ ) -> IterableProgress [ IterableType ] : ...
121162
122163
123164def progress (
@@ -127,7 +168,7 @@ def progress(
127168 file : Optional [TextIO ] = None ,
128169 mininterval : float = 0.5 ,
129170 ** kwargs : object ,
130- ) -> tqdm :
171+ ) -> Union [ Progress , IterableProgress [ IterableType ]] :
131172 return tqdm (
132173 iterable ,
133174 desc = desc ,
0 commit comments