diff --git a/mypy/checker.py b/mypy/checker.py index fc636e9a7218..323875d47304 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6783,6 +6783,36 @@ def narrow_type_by_identity_equality( enum_comparison_is_ambiguous and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1 ): + # For unions with both StrEnum/IntEnum and Literal/None items, + # narrow the Literal/None items while keeping enum items as-is. + orig_type = get_proper_type(coerce_to_literal(operand_types[i])) + if isinstance(orig_type, UnionType): + yes_items: list[Type] = [] + no_items: list[Type] = [] + has_narrowable = False + target = TypeRange(target_type, is_upper_bound=False) + for item in orig_type.items: + p_item = get_proper_type(item) + is_enum = bool(ambiguous_enum_equality_keys(item) - {""}) + if not is_enum and isinstance(p_item, (LiteralType, NoneType)): + has_narrowable = True + y, n = conditional_types( + item, [target], default=item, from_equality=True + ) + yes_items.append(y) + no_items.append(n) + else: + yes_items.append(item) + no_items.append(item) + if has_narrowable: + if_map, else_map = conditional_types_to_typemaps( + operands[i], + UnionType.make_union(yes_items), + UnionType.make_union(no_items), + ) + all_if_maps.append(if_map) + if is_target_for_value_narrowing(get_proper_type(target_type)): + all_else_maps.append(else_map) continue target = TypeRange(target_type, is_upper_bound=False) diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index c05dfdef2bf7..151e031d7cee 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2784,9 +2784,8 @@ def f1(a: Foo | Literal['foo']) -> Foo: reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']" return Foo.FOO - # Ideally this passes - reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']" - return a # E: Incompatible return value type (got "Foo | Literal['foo']", expected "Foo") + reveal_type(a) # N: Revealed type is "Literal[__main__.Foo.FOO]" + return a [builtins fixtures/primitives.pyi] [case testStrEnumEqualityAlias]