Skip to content

Commit c90b301

Browse files
authored
Propagate narrowing within chained comparisons (#21160)
Fixes #21149 (the part that isn't a regression)
1 parent 8f2e6ad commit c90b301

2 files changed

Lines changed: 39 additions & 3 deletions

File tree

mypy/checker.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6558,6 +6558,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65586558
operands = [collapse_walrus(x) for x in node.operands]
65596559
operand_types = []
65606560
narrowable_operand_index_to_hash = {}
6561+
narrowable_operand_hash_to_index = {}
65616562
for i, expr in enumerate(operands):
65626563
if not self.has_type(expr):
65636564
return {}, {}
@@ -6582,6 +6583,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65826583
h = literal_hash(expr)
65836584
if h is not None:
65846585
narrowable_operand_index_to_hash[i] = h
6586+
narrowable_operand_hash_to_index[h] = i
65856587

65866588
# Step 2: Group operands chained by either the 'is' or '==' operands
65876589
# together. For all other operands, we keep them in groups of size 2.
@@ -6673,6 +6675,18 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
66736675

66746676
partial_type_maps.append((if_map, else_map))
66756677

6678+
# Chained comparisons are conjunctions evaluated left-to-right. Feed what we learned
6679+
# from earlier true comparisons into later comparisons, similarly to `and`.
6680+
if len(simplified_operator_list) > 1:
6681+
for expr, expr_type in if_map.items():
6682+
h = literal_hash(expr)
6683+
if h is None or h not in narrowable_operand_hash_to_index:
6684+
continue
6685+
operand_index = narrowable_operand_hash_to_index[h]
6686+
operand_types[operand_index] = meet_types(
6687+
operand_types[operand_index], expr_type
6688+
)
6689+
66766690
# If we have found non-trivial restrictions from the regular comparisons,
66776691
# then return soon. Otherwise try to infer restrictions involving `len(x)`.
66786692
# TODO: support regular and len() narrowing in the same chain.

test-data/unit/check-narrowing.test

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3264,15 +3264,15 @@ def bad_but_should_pass(has_key: bool, key: bool, s: tuple[bool, ...]) -> None:
32643264
reveal_type(key) # N: Revealed type is "builtins.bool"
32653265
[builtins fixtures/primitives.pyi]
32663266

3267-
[case testNarrowChainedComparisonMeet]
3267+
[case testNarrowChainedComparisonMeetAndForwardPropagation]
32683268
# flags: --strict-equality --warn-unreachable
32693269
from __future__ import annotations
32703270
from typing import Any
32713271

32723272
def f1(a: str | None, b: str | None) -> None:
32733273
if None is not a == b:
32743274
reveal_type(a) # N: Revealed type is "builtins.str"
3275-
reveal_type(b) # N: Revealed type is "builtins.str | None"
3275+
reveal_type(b) # N: Revealed type is "builtins.str"
32763276

32773277
if (None is not a) and (a == b):
32783278
reveal_type(a) # N: Revealed type is "builtins.str"
@@ -3290,11 +3290,33 @@ def f2(a: Any | None, b: str | None) -> None:
32903290
def f3(a: str | None, b: Any | None) -> None:
32913291
if None is not a == b:
32923292
reveal_type(a) # N: Revealed type is "builtins.str"
3293-
reveal_type(b) # N: Revealed type is "Any | builtins.str | None"
3293+
reveal_type(b) # N: Revealed type is "Any | builtins.str"
32943294

32953295
if (None is not a) and (a == b):
32963296
reveal_type(a) # N: Revealed type is "builtins.str"
32973297
reveal_type(b) # N: Revealed type is "Any | builtins.str"
3298+
3299+
def f4(a: str | None, b: str | None, c: str | None) -> None:
3300+
if None is not a == b == c:
3301+
reveal_type(a) # N: Revealed type is "builtins.str"
3302+
reveal_type(b) # N: Revealed type is "builtins.str"
3303+
reveal_type(c) # N: Revealed type is "builtins.str"
3304+
3305+
if (None is not a) and (a == b) and (b == c):
3306+
reveal_type(a) # N: Revealed type is "builtins.str"
3307+
reveal_type(b) # N: Revealed type is "builtins.str"
3308+
reveal_type(c) # N: Revealed type is "builtins.str"
3309+
3310+
def f5(pair: tuple[None, int] | tuple[str, str], other: str | None) -> None:
3311+
if None is not pair[0] == other:
3312+
reveal_type(pair[0]) # N: Revealed type is "builtins.str"
3313+
reveal_type(pair) # N: Revealed type is "tuple[builtins.str, builtins.str]"
3314+
reveal_type(other) # N: Revealed type is "builtins.str"
3315+
3316+
if (None is not pair[0]) and (pair[0] == other):
3317+
reveal_type(pair[0]) # N: Revealed type is "builtins.str"
3318+
reveal_type(pair) # N: Revealed type is "tuple[builtins.str, builtins.str]"
3319+
reveal_type(other) # N: Revealed type is "builtins.str"
32983320
[builtins fixtures/primitives.pyi]
32993321

33003322
[case testNarrowTypeObject]

0 commit comments

Comments
 (0)