Skip to content

Commit 11fa980

Browse files
committed
Keep NoneType in Union TypeVar values
1 parent 9a9a201 commit 11fa980

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

mypy/expandtype.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,25 +50,37 @@
5050

5151
@overload
5252
def expand_type(
53-
typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
53+
typ: ProperType,
54+
env: Mapping[TypeVarId, Type],
55+
allow_erased_callables: bool = ...,
56+
*,
57+
keep_none_type: bool = ...,
5458
) -> ProperType:
5559
...
5660

5761

5862
@overload
5963
def expand_type(
60-
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
64+
typ: Type,
65+
env: Mapping[TypeVarId, Type],
66+
allow_erased_callables: bool = ...,
67+
*,
68+
keep_none_type: bool = ...,
6169
) -> Type:
6270
...
6371

6472

6573
def expand_type(
66-
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
74+
typ: Type,
75+
env: Mapping[TypeVarId, Type],
76+
allow_erased_callables: bool = False,
77+
*,
78+
keep_none_type: bool = False,
6779
) -> Type:
6880
"""Substitute any type variable references in a type given by a type
6981
environment.
7082
"""
71-
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables))
83+
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables, keep_none_type))
7284

7385

7486
@overload
@@ -183,10 +195,14 @@ class ExpandTypeVisitor(TypeVisitor[Type]):
183195
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
184196

185197
def __init__(
186-
self, variables: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
198+
self,
199+
variables: Mapping[TypeVarId, Type],
200+
allow_erased_callables: bool = False,
201+
keep_none_type: bool = False,
187202
) -> None:
188203
self.variables = variables
189204
self.allow_erased_callables = allow_erased_callables
205+
self.keep_none_type = keep_none_type
190206
self.recursive_guard: set[Type | tuple[int, Type]] = set()
191207

192208
def visit_unbound_type(self, t: UnboundType) -> Type:
@@ -470,7 +486,7 @@ def visit_union_type(self, t: UnionType) -> Type:
470486
# might be subtypes of others, however calling make_simplified_union()
471487
# can cause recursion, so we just remove strict duplicates.
472488
return UnionType.make_union(
473-
remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
489+
remove_trivial(flatten_nested_unions(expanded), self.keep_none_type), t.line, t.column
474490
)
475491

476492
def visit_partial_type(self, t: PartialType) -> Type:

mypy/typeanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1706,7 +1706,7 @@ def fix_instance(
17061706
args.append(arg)
17071707
env[tv.id] = arg
17081708
t.args = tuple(args)
1709-
fixed = expand_type(t, env)
1709+
fixed = expand_type(t, env, keep_none_type=True)
17101710
assert isinstance(fixed, Instance)
17111711
t.args = fixed.args
17121712

mypy/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3704,7 +3704,7 @@ def store_argument_type(
37043704
defn.arguments[i].variable.type = arg_type
37053705

37063706

3707-
def remove_trivial(types: Iterable[Type]) -> list[Type]:
3707+
def remove_trivial(types: Iterable[Type], keep_none_type: bool = False) -> list[Type]:
37083708
"""Make trivial simplifications on a list of types without calling is_subtype().
37093709
37103710
This makes following simplifications:
@@ -3719,7 +3719,7 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]:
37193719
p_t = get_proper_type(t)
37203720
if isinstance(p_t, UninhabitedType):
37213721
continue
3722-
if isinstance(p_t, NoneType) and not state.strict_optional:
3722+
if isinstance(p_t, NoneType) and not state.strict_optional and not keep_none_type:
37233723
removed_none = True
37243724
continue
37253725
if isinstance(p_t, Instance) and p_t.type.fullname == "builtins.object":

0 commit comments

Comments
 (0)