Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6783,6 +6783,36 @@ def narrow_type_by_identity_equality(
enum_comparison_is_ambiguous
and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1
):
# For unions with both StrEnum/IntEnum and Literal/None items,
# narrow the Literal/None items while keeping enum items as-is.
orig_type = get_proper_type(coerce_to_literal(operand_types[i]))
if isinstance(orig_type, UnionType):
yes_items: list[Type] = []
no_items: list[Type] = []
has_narrowable = False
target = TypeRange(target_type, is_upper_bound=False)
for item in orig_type.items:
p_item = get_proper_type(item)
is_enum = bool(ambiguous_enum_equality_keys(item) - {"<other>"})
if not is_enum and isinstance(p_item, (LiteralType, NoneType)):
has_narrowable = True
y, n = conditional_types(
item, [target], default=item, from_equality=True
)
yes_items.append(y)
no_items.append(n)
else:
yes_items.append(item)
no_items.append(item)
if has_narrowable:
if_map, else_map = conditional_types_to_typemaps(
operands[i],
UnionType.make_union(yes_items),
UnionType.make_union(no_items),
)
all_if_maps.append(if_map)
if is_target_for_value_narrowing(get_proper_type(target_type)):
all_else_maps.append(else_map)
continue

target = TypeRange(target_type, is_upper_bound=False)
Expand Down
5 changes: 2 additions & 3 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -2784,9 +2784,8 @@ def f1(a: Foo | Literal['foo']) -> Foo:
reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
return Foo.FOO

# Ideally this passes
reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
return a # E: Incompatible return value type (got "Foo | Literal['foo']", expected "Foo")
reveal_type(a) # N: Revealed type is "Literal[__main__.Foo.FOO]"
return a
[builtins fixtures/primitives.pyi]

[case testStrEnumEqualityAlias]
Expand Down