diff --git a/mypy/checker.py b/mypy/checker.py index 396aee8d2503..399ccd9e9247 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5090,23 +5090,22 @@ def visit_if_stmt(self, s: IfStmt) -> None: # This frame records the knowledge from previous if/elif clauses not being taken. # Fall-through to the original frame is handled explicitly in each block. with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0): - for e, b in zip(s.expr, s.body): - t = get_proper_type(self.expr_checker.accept(e)) + t = get_proper_type(self.expr_checker.accept(s.expr)) - if isinstance(t, DeletedType): - self.msg.deleted_as_rvalue(t, s) + if isinstance(t, DeletedType): + self.msg.deleted_as_rvalue(t, s) - if_map, else_map = self.find_isinstance_check(e) + if_map, else_map = self.find_isinstance_check(s.expr) - s.unreachable_else = is_unreachable_map(else_map) + s.unreachable_else = is_unreachable_map(else_map) - # XXX Issue a warning if condition is always False? - with self.binder.frame_context(can_skip=True, fall_through=2): - self.push_type_map(if_map, from_assignment=False) - self.accept(b) + # XXX Issue a warning if condition is always False? + with self.binder.frame_context(can_skip=True, fall_through=2): + self.push_type_map(if_map, from_assignment=False) + self.accept(s.body) - # XXX Issue a warning if condition is always True? - self.push_type_map(else_map, from_assignment=False) + # XXX Issue a warning if condition is always True? + self.push_type_map(else_map, from_assignment=False) with self.binder.frame_context(can_skip=False, fall_through=2): if s.else_body: @@ -5114,7 +5113,7 @@ def visit_if_stmt(self, s: IfStmt) -> None: def visit_while_stmt(self, s: WhileStmt) -> None: """Type check a while statement.""" - if_stmt = IfStmt([s.expr], [s.body], None) + if_stmt = IfStmt(s.expr, s.body, None) if_stmt.set_line(s) self.accept_loop(if_stmt, s.else_body, exit_condition=s.expr) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 8ef905a567d1..cb66536367a0 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -773,19 +773,19 @@ def _check_ifstmt_for_overloads( # Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef. # Multiple overloads have already been merged as OverloadedFuncDef. if not ( - len(stmt.body[0].body) == 1 + len(stmt.body.body) == 1 and ( - isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) + isinstance(stmt.body.body[0], (Decorator, OverloadedFuncDef)) or current_overload_name is not None - and isinstance(stmt.body[0].body[0], FuncDef) + and isinstance(stmt.body.body[0], FuncDef) ) - or len(stmt.body[0].body) > 1 - and isinstance(stmt.body[0].body[-1], OverloadedFuncDef) - and all(self._is_stripped_if_stmt(if_stmt) for if_stmt in stmt.body[0].body[:-1]) + or len(stmt.body.body) > 1 + and isinstance(stmt.body.body[-1], OverloadedFuncDef) + and all(self._is_stripped_if_stmt(if_stmt) for if_stmt in stmt.body.body[:-1]) ): return None - overload_name = cast(Decorator | FuncDef | OverloadedFuncDef, stmt.body[0].body[-1]).name + overload_name = cast(Decorator | FuncDef | OverloadedFuncDef, stmt.body.body[-1]).name if stmt.else_body is None: return overload_name @@ -816,20 +816,20 @@ def _get_executable_if_block_with_overloads( i.e. the truth value is unknown. """ infer_reachability_of_if_statement(stmt, self.options) - if stmt.else_body is None and stmt.body[0].is_unreachable is True: + if stmt.else_body is None and stmt.body.is_unreachable is True: # always False condition with no else return None, None if ( stmt.else_body is None - or stmt.body[0].is_unreachable is False + or stmt.body.is_unreachable is False and stmt.else_body.is_unreachable is False ): # The truth value is unknown, thus not conclusive return None, stmt if stmt.else_body.is_unreachable is True: # else_body will be set unreachable if condition is always True - return stmt.body[0], None - if stmt.body[0].is_unreachable is True: + return stmt.body, None + if stmt.body.is_unreachable is True: # body will be set unreachable if condition is always False # else_body can contain an IfStmt itself (for elif) -> do a recursive check if isinstance(stmt.else_body.body[0], IfStmt): @@ -843,8 +843,7 @@ def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None: Needed to still be able to check the conditions after the contents have been merged with the surrounding function overloads. """ - if len(stmt.body) == 1: - stmt.body[0].body = [] + stmt.body.body = [] if stmt.else_body and len(stmt.else_body.body) == 1: if isinstance(stmt.else_body.body[0], IfStmt): self._strip_contents_from_if_stmt(stmt.else_body.body[0]) @@ -859,7 +858,7 @@ def _is_stripped_if_stmt(self, stmt: Statement) -> bool: if not isinstance(stmt, IfStmt): return False - if not (len(stmt.body) == 1 and len(stmt.body[0].body) == 0): + if stmt.body.body: # Body not empty return False @@ -1328,9 +1327,7 @@ def visit_While(self, n: ast3.While) -> WhileStmt: # If(expr test, stmt* body, stmt* orelse) def visit_If(self, n: ast3.If) -> IfStmt: - node = IfStmt( - [self.visit(n.test)], [self.as_required_block(n.body)], self.as_block(n.orelse) - ) + node = IfStmt(self.visit(n.test), self.as_required_block(n.body), self.as_block(n.orelse)) return self.set_line(node, n) # With(withitem* items, stmt* body, string? type_comment) diff --git a/mypy/nodes.py b/mypy/nodes.py index 4168b2e00f15..db219d3630a8 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1836,8 +1836,8 @@ class IfStmt(Statement): __match_args__ = ("expr", "body", "else_body", "unreachable_else") - expr: list[Expression] - body: list[Block] + expr: Expression + body: Block else_body: Block | None # (If there is actually no else statement, semantic analysis may nevertheless create an # empty else block and mark it permanently as unreachable to tell that the control flow @@ -1846,7 +1846,7 @@ class IfStmt(Statement): # (Type checking may modify this flag repeatedly to indicate whether an actually available # or unavailable else block is unreachable, considering the current type information.) - def __init__(self, expr: list[Expression], body: list[Block], else_body: Block | None) -> None: + def __init__(self, expr: Expression, body: Block, else_body: Block | None) -> None: super().__init__() self.expr = expr self.body = body diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 2bff1669becb..35a0cc02bfdc 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -391,13 +391,10 @@ def visit_assignment_expr(self, o: AssignmentExpr) -> None: self.process_lvalue(o.target) def visit_if_stmt(self, o: IfStmt) -> None: - for e in o.expr: - e.accept(self) + o.expr.accept(self) self.tracker.start_branch_statement() - for b in o.body: - if b.is_unreachable: - continue - b.accept(self) + if not o.body.is_unreachable: + o.body.accept(self) self.tracker.next_branch() if o.unreachable_else: self.tracker.skip_branch() diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 97faadbad064..3eb8d815ad6e 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -513,9 +513,8 @@ def reset_init_only_vars(self, info: TypeInfo, attributes: list[DataclassAttribu def _get_assignment_statements_from_if_statement( self, stmt: IfStmt ) -> Iterator[AssignmentStmt]: - for body in stmt.body: - if not body.is_unreachable: - yield from self._get_assignment_statements_from_block(body) + if not stmt.body.is_unreachable: + yield from self._get_assignment_statements_from_block(stmt.body) if stmt.else_body is not None and not stmt.else_body.is_unreachable: yield from self._get_assignment_statements_from_block(stmt.else_body) diff --git a/mypy/reachability.py b/mypy/reachability.py index 132c269e96af..7f09a0867003 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -51,29 +51,25 @@ def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: - for i in range(len(s.expr)): - result = infer_condition_value(s.expr[i], options) - if result in (ALWAYS_FALSE, MYPY_FALSE): - # The condition is considered always false, so we skip the if/elif body. - mark_block_unreachable(s.body[i]) - elif result in (ALWAYS_TRUE, MYPY_TRUE): - # This condition is considered always true, so all of the remaining - # elif/else bodies should not be checked. - if result == MYPY_TRUE: - # This condition is false at runtime; this will affect - # import priorities. - mark_block_mypy_only(s.body[i]) - for body in s.body[i + 1 :]: - mark_block_unreachable(body) - - # Make sure else body always exists and is marked as - # unreachable so the type checker always knows that - # all control flow paths will flow through the if - # statement body. - if not s.else_body: - s.else_body = Block([]) - mark_block_unreachable(s.else_body) - break + result = infer_condition_value(s.expr, options) + if result in (ALWAYS_FALSE, MYPY_FALSE): + # The condition is considered always false, so we skip the if/elif body. + mark_block_unreachable(s.body) + elif result in (ALWAYS_TRUE, MYPY_TRUE): + # This condition is considered always true, so all of the remaining + # elif/else bodies should not be checked. + if result == MYPY_TRUE: + # This condition is false at runtime; this will affect + # import priorities. + mark_block_mypy_only(s.body) + + # Make sure else body always exists and is marked as + # unreachable so the type checker always knows that + # all control flow paths will flow through the if + # statement body. + if not s.else_body: + s.else_body = Block([]) + mark_block_unreachable(s.else_body) def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None: diff --git a/mypy/semanal.py b/mypy/semanal.py index f38a71cb16e3..4f87afc22f34 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -622,8 +622,7 @@ def prepare_typing_namespace(self, file_node: MypyFile, aliases: dict[str, str]) def helper(defs: list[Statement]) -> None: for stmt in defs.copy(): if isinstance(stmt, IfStmt): - for body in stmt.body: - helper(body.body) + helper(stmt.body.body) if stmt.else_body: helper(stmt.else_body.body) if ( @@ -5499,9 +5498,8 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: def visit_if_stmt(self, s: IfStmt) -> None: self.statement = s infer_reachability_of_if_statement(s, self.options) - for i in range(len(s.expr)): - s.expr[i].accept(self) - self.visit_block(s.body[i]) + s.expr.accept(self) + self.visit_block(s.body) self.visit_block_maybe(s.else_body) def visit_try_stmt(self, s: TryStmt) -> None: diff --git a/mypy/semanal_pass1.py b/mypy/semanal_pass1.py index 266fd236a01f..bce15429186e 100644 --- a/mypy/semanal_pass1.py +++ b/mypy/semanal_pass1.py @@ -118,10 +118,8 @@ def visit_import(self, node: Import) -> None: def visit_if_stmt(self, s: IfStmt) -> None: infer_reachability_of_if_statement(s, self.options) - for expr in s.expr: - expr.accept(self) - for node in s.body: - node.accept(self) + s.expr.accept(self) + s.body.accept(self) if s.else_body: s.else_body.accept(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 168a8bcffdc7..e219b6796dd0 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -263,11 +263,7 @@ def visit_return_stmt(self, o: mypy.nodes.ReturnStmt) -> str: return self.dump([o.expr], o) def visit_if_stmt(self, o: mypy.nodes.IfStmt) -> str: - a: list[Any] = [] - for i in range(len(o.expr)): - a.append(("If", [o.expr[i]])) - a.append(("Then", o.body[i].body)) - + a: list[Any] = [("If", [o.expr]), ("Then", o.body.body)] if not o.else_body: return self.dump(a, o) else: diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 5d6149b97507..48adfecfc8ef 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -1222,13 +1222,12 @@ def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None: def visit_if_stmt(self, o: IfStmt) -> None: # Ignore if __name__ == '__main__'. - expr = o.expr[0] if ( - isinstance(expr, ComparisonExpr) - and isinstance(expr.operands[0], NameExpr) - and isinstance(expr.operands[1], StrExpr) - and expr.operands[0].name == "__name__" - and "__main__" in expr.operands[1].value + isinstance(o.expr, ComparisonExpr) + and isinstance(o.expr.operands[0], NameExpr) + and isinstance(o.expr.operands[1], StrExpr) + and o.expr.operands[0].name == "__name__" + and "__main__" in o.expr.operands[1].value ): return super().visit_if_stmt(o) diff --git a/mypy/traverser.py b/mypy/traverser.py index c313e70308ad..7347f463acfa 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -208,10 +208,8 @@ def visit_del_stmt(self, o: DelStmt, /) -> None: o.expr.accept(self) def visit_if_stmt(self, o: IfStmt, /) -> None: - for e in o.expr: - e.accept(self) - for b in o.body: - b.accept(self) + o.expr.accept(self) + o.body.accept(self) if o.else_body: o.else_body.accept(self) @@ -1179,8 +1177,7 @@ def visit_assert_stmt(self, o: AssertStmt, /) -> None: pass def visit_if_stmt(self, o: IfStmt, /) -> None: - for b in o.body: - b.accept(self) + o.body.accept(self) if o.else_body: o.else_body.accept(self) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 1a76a50a2d94..f1590382926e 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -360,9 +360,7 @@ def visit_del_stmt(self, node: DelStmt) -> DelStmt: def visit_if_stmt(self, node: IfStmt) -> IfStmt: return IfStmt( - self.expressions(node.expr), - self.blocks(node.body), - self.optional_block(node.else_body), + self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body) ) def visit_break_stmt(self, node: BreakStmt) -> BreakStmt: diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 22ca267d7a2b..0c34773fa6f6 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -440,12 +440,9 @@ def transform_if_stmt(builder: IRBuilder, stmt: IfStmt) -> None: if_body, next = BasicBlock(), BasicBlock() else_body = BasicBlock() if stmt.else_body else next - # If statements are normalized - assert len(stmt.expr) == 1 - - process_conditional(builder, stmt.expr[0], if_body, else_body) + process_conditional(builder, stmt.expr, if_body, else_body) builder.activate_block(if_body) - builder.accept(stmt.body[0]) + builder.accept(stmt.body) builder.goto(next) if stmt.else_body: builder.activate_block(else_body)