From a69197d720b37965cb848e1b1ded62e8dd617bf7 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Thu, 5 Feb 2026 17:40:44 -0800 Subject: [PATCH] Fix bug when narrowing union containing custom eq against custom eq Fixes #20750 --- mypy/checker.py | 6 +++- test-data/unit/check-narrowing.test | 44 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 396aee8d2503..cb5cac93e810 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6736,13 +6736,17 @@ def narrow_type_by_identity_equality( or_else_maps: list[TypeMap] = [] for expr_type in union_expr_type.items: if has_custom_eq_checks(expr_type): - # Always include union items with custom __eq__ in the type + # Always include the union items with custom __eq__ in the type # we narrow to in the if_map or_if_maps.append({operands[i]: expr_type}) expr_type = coerce_to_literal(try_expanding_sum_type_to_union(expr_type, None)) for j in expr_indices: if j in custom_eq_indices: + if i == j: + continue + # If we compare to a target with custom __eq__, we cannot narrow at all + or_if_maps.append({}) continue target_type = operand_types[j] if should_coerce_literals: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 117d0e72ed79..cf89518ae0c2 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1005,6 +1005,50 @@ def f(x: Custom | None, y: int | None): reveal_type(y) # N: Revealed type is "builtins.int | None" [builtins fixtures/primitives.pyi] +[case testNarrowingCustomEqualityUnion4] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Any + +class Custom1: + def __eq__(self, other: object) -> bool: + raise + +class Custom2: + def __eq__(self, other: object) -> bool: + raise + +def f(x: Custom1 | int, y: Custom2 | int): + if x == y: + reveal_type(x) # N: Revealed type is "__main__.Custom1 | builtins.int" + reveal_type(y) # N: Revealed type is "__main__.Custom2 | builtins.int" + else: + reveal_type(x) # N: Revealed type is "__main__.Custom1 | builtins.int" + reveal_type(y) # N: Revealed type is "__main__.Custom2 | builtins.int" +[builtins fixtures/primitives.pyi] + +[case testNarrowingCustomEqualitySubclass] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from typing import Any + +class Custom: + def __eq__(self, other: object) -> bool: + raise + +class CustomSub(Custom): + def __eq__(self, other: object) -> bool: + raise + +def f(x: Custom, y: CustomSub): + if x == y: + reveal_type(x) # N: Revealed type is "__main__.Custom" + reveal_type(y) # N: Revealed type is "__main__.CustomSub" + else: + reveal_type(x) # N: Revealed type is "__main__.Custom" + reveal_type(y) # N: Revealed type is "__main__.CustomSub" +[builtins fixtures/tuple.pyi] + [case testNarrowingUnreachableCases] # flags: --strict-equality --warn-unreachable from typing import Literal, Union