⚡️ Speed up function extract_imports_for_class by 427% in PR #1335 (gpu-flag) #1350

Closed
codeflash-ai[bot] wants to merge 5 commits into gpu-flag from codeflash/optimize-pr1335-2026-02-04T00.49.48

codeflash-ai bot commented Feb 4, 2026

⚡️ This pull request contains optimizations for PR #1335

If you approve this dependent PR, these changes will be merged into the original PR branch gpu-flag.

This PR will be automatically closed if the original PR is merged.


📄 427% (4.27x) speedup for extract_imports_for_class in codeflash/context/code_context_extractor.py

⏱️ Runtime: 2.69 milliseconds → 510 microseconds (best of 250 runs)
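
For readers checking the headline number, here is a minimal sketch of the arithmetic. It assumes the percentage is computed as (original − optimized) / optimized, which reproduces the reported figure; the helper name `speedup_percent` is hypothetical and not part of Codeflash.

```python
def speedup_percent(original_s: float, optimized_s: float) -> float:
    """Relative gain of the optimized runtime over the original, in percent."""
    return (original_s - optimized_s) / optimized_s * 100


original = 2.69e-3   # 2.69 milliseconds, from the report above
optimized = 510e-6   # 510 microseconds, from the report above

pct = speedup_percent(original, optimized)
ratio = original / optimized

print(f"{pct:.0f}% faster; the original takes {ratio:.2f}x as long")
```

Under this assumption, the "4.27x" in the summary line is the relative gain, while the raw runtime ratio is about one higher.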

📝 Explanation and details

The optimization achieves a 427% speedup (from 2.69ms to 510μs) by addressing two performance bottlenecks identified by the line profiler:

Key Optimization 1: Replacing ast.walk() with Direct Traversal (65.8% → 2.5% of runtime)

The original code used ast.walk(class_node), which visits every node in the class's subtree (3,281 hits in profiling). The optimized version iterates directly over class_node.body, visiting only top-level class members. This is much cheaper because:

  • ast.walk() traverses the entire nested tree, including function bodies and nested expressions
  • Direct iteration over class_node.body examines only the immediate class members, where annotations and field() calls actually appear
  • The field() detection logic is folded into the annotation check, reducing redundant type checks
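
The difference in work can be illustrated with a small sketch. This is not the code from the PR: `annotation_names` is a hypothetical helper that shows the shape of a body-only traversal, and the sample class is made up for the example.

```python
import ast
import textwrap

source = textwrap.dedent('''
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Example:
        items: List[int] = field(default_factory=list)

        def total(self) -> int:
            return sum(x * 2 for x in self.items)
''')

tree = ast.parse(source)
class_node = next(n for n in tree.body if isinstance(n, ast.ClassDef))

# ast.walk yields every node under the class, including the method body.
walked = sum(1 for _ in ast.walk(class_node))
# The class body holds only the direct members: one AnnAssign, one FunctionDef.
top_level = len(class_node.body)


def annotation_names(node: ast.ClassDef) -> set:
    """Hypothetical helper: collect names from class-level annotations and calls."""
    names = set()
    for stmt in node.body:
        if isinstance(stmt, ast.AnnAssign):
            # Walk only the (small) annotation expression, not the whole class.
            names.update(n.id for n in ast.walk(stmt.annotation) if isinstance(n, ast.Name))
            value = stmt.value
        elif isinstance(stmt, ast.Assign):
            value = stmt.value
        else:
            continue  # methods and other statements are skipped entirely
        # Detect calls such as field(...) on the right-hand side.
        if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
            names.add(value.func.id)
    return names


print(walked, top_level, sorted(annotation_names(class_node)))
```

The method body contributes most of the nodes that ast.walk visits, yet none of them affect the class-level annotations, which is why skipping it pays off.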

Key Optimization 2: Early Termination in Import Collection (0.7% → 2.3% of runtime)

The optimized code tracks remaining_names and exits the import search loop as soon as all needed names are found. In the large-scale test with 200 imports, this cuts iterations from 589 to 420, skipping 169 iterations (roughly a 29% reduction). Each matched name is removed with remaining_names.discard(), so the loop stops once the set is empty.
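
A minimal sketch of the early-exit pattern follows. It is an illustration under assumptions, not the PR's implementation: `collect_import_lines` and its return shape are hypothetical, and it uses ast.unparse (Python 3.9+) instead of slicing the original source text.

```python
import ast


def collect_import_lines(module_tree: ast.Module, needed_names: set):
    """Hypothetical helper: return (import lines, number of import nodes inspected)."""
    remaining = set(needed_names)
    lines = []
    inspected = 0
    for node in module_tree.body:
        if not remaining:
            break  # every needed name has been resolved, so stop scanning
        if not isinstance(node, (ast.Import, ast.ImportFrom)):
            continue
        inspected += 1
        matched = False
        for alias in node.names:
            # The name bound in the module: the alias if present, else the top-level name.
            bound = alias.asname or alias.name.split(".")[0]
            if bound in remaining:
                remaining.discard(bound)
                matched = True
        if matched:
            lines.append(ast.unparse(node))
    return lines, inspected


source = "\n".join(f"from mod_{i} import Name{i}" for i in range(200))
tree = ast.parse(source)
lines, inspected = collect_import_lines(tree, {"Name0", "Name4"})
print(lines, inspected)
```

With both needed names among the first five imports, the loop inspects five import nodes instead of all 200.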

Impact on Hot Paths:

Based on function_references, this function is called in:

  1. Test benchmarking infrastructure - where it's repeatedly executed for performance measurement
  2. Code context extraction workflows - potentially called multiple times when analyzing class hierarchies

The optimization is particularly effective for:

  • Large classes with many annotations (890% speedup in test with 100 annotations)
  • Modules with many imports (176-300% speedup when 20-50 imports are present)
  • Complex nested generics (334% speedup for deeply nested types)
  • Dataclass-heavy codebases where field() detection is critical

The early termination optimization is especially valuable when the needed imports appear early in the module's import statements, which is common in well-organized Python codebases.

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 39 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 96.1% |

🌀 Generated Regression Tests
from __future__ import annotations

# imports
import ast
import textwrap

import pytest  # used for our unit tests
from codeflash.context.code_context_extractor import extract_imports_for_class

# unit tests

def _get_class_node_and_module(source: str, class_name: str):
    """
    Helper to parse a source string into an AST, return the module AST and the specific ClassDef.
    """
    tree = ast.parse(source)
    # Find the class with the given name (there should be exactly one in our tests)
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return tree, node
    raise AssertionError(f"Class {class_name} not found in provided source")

def test_basic_extracts_decorator_annotation_and_field_imports():
    # Basic scenario: dataclass decorator, List annotation and field() call.
    source = textwrap.dedent(
        """\
        from dataclasses import dataclass, field
        from typing import List
        import abc

        @dataclass
        class MyClass:
            x: List[int]
            y = field(default=0)
        """
    )
    # Parse and locate the class
    module_tree, class_node = _get_class_node_and_module(source, "MyClass")

    # Run the function under test
    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 27.2μs -> 7.40μs (267% faster)

    # We expect the dataclasses import line (provides dataclass and field) and typing.List
    expected = "from dataclasses import dataclass, field\nfrom typing import List"

def test_attribute_decorator_and_attribute_base_classes():
    # Edge: decorators and base classes given as attributes, e.g., dataclasses.dataclass and abc.ABC
    source = textwrap.dedent(
        """\
        import dataclasses
        import abc

        @dataclasses.dataclass
        class AttrClass(abc.ABC):
            pass
        """
    )
    module_tree, class_node = _get_class_node_and_module(source, "AttrClass")

    # Execute
    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 16.8μs -> 5.43μs (209% faster)

    # Both import lines should be returned, in the order they appear
    expected = "import dataclasses\nimport abc"

def test_import_alias_and_from_import_alias_handling():
    # Alias handling for both import ... as and from ... import ... as
    source = textwrap.dedent(
        """\
        import numpy as np
        from typing import List as L

        class AliasClass:
            arr: np.ndarray
            items: L[int]
        """
    )
    module_tree, class_node = _get_class_node_and_module(source, "AliasClass")

    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 22.5μs -> 6.62μs (239% faster)

    # The "import numpy as np" line should match because alias.asname == 'np'
    # The "from typing import List as L" line should match because alias.asname == 'L'
    expected = "import numpy as np\nfrom typing import List as L"

def test_union_and_tuple_annotations_collect_names_from_annotations():
    # Complex annotations: union using '|' (BinOp handling) and tuple annotation (Tuple handling)
    source = textwrap.dedent(
        """\
        from mod_a import A
        from mod_b import B
        from mod_c import C, D

        class ComplexAnn:
            a: A | B
            b: (C, D)
        """
    )
    module_tree, class_node = _get_class_node_and_module(source, "ComplexAnn")

    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 25.4μs -> 7.37μs (245% faster)

    # mod_a, mod_b, and mod_c import lines should be present, in that order
    expected = "from mod_a import A\nfrom mod_b import B\nfrom mod_c import C, D"

def test_no_matching_imports_returns_empty_string():
    # When only builtins are used (int, str), nothing should be returned
    source = textwrap.dedent(
        """\
        class BuiltinOnly:
            x: int
            def method(self) -> str:
                return 'ok'
        """
    )
    module_tree, class_node = _get_class_node_and_module(source, "BuiltinOnly")

    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 20.9μs -> 3.67μs (469% faster)

def test_single_import_line_with_multiple_names_only_returned_once():
    # Ensure that when multiple needed names come from the same import-from line,
    # that line only appears once in the output.
    source = textwrap.dedent(
        """\
        from typing import List, Dict

        class DuplicateImports:
            a: List[int]
            b: Dict[str, int]
        """
    )
    module_tree, class_node = _get_class_node_and_module(source, "DuplicateImports")

    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 25.6μs -> 6.27μs (308% faster)

    expected = "from typing import List, Dict"

def test_large_scale_many_imports_with_subset_used():
    # Large-scale scenario: many import lines but only a subset are needed by the class.
    # We generate 200 import lines and reference every 4th Name to create ~50 needed imports.
    num_imports = 200
    imports = [f"from mod_{i} import Name{i}" for i in range(num_imports)]
    # Build class body referencing every 4th Name via annotations
    used_indices = list(range(0, num_imports, 4))  # roughly 50 names
    ann_lines = [f"    a{i}: Name{idx}" for i, idx in enumerate(used_indices)]
    # Compose the full source; the import block must come first so line numbers match
    source_lines = imports + ["", "class BigClass:"] + (ann_lines or ["    pass"])
    source = "\n".join(source_lines) + "\n"
    module_tree = ast.parse(source)

    # Find the class node
    class_node = None
    for node in module_tree.body:
        if isinstance(node, ast.ClassDef) and node.name == "BigClass":
            class_node = node
            break

    # Execute the extractor
    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 244μs -> 87.3μs (180% faster)

    # Expected lines are exactly the import lines for the used indices, in the original order
    expected_lines = [source_lines[i] for i in used_indices]  # because import lines start at index 0
    expected = "\n".join(expected_lines)

def test_field_call_detection_when_field_is_used_in_class_body():
    # Ensure that simple calls like field() are detected even if no annotation references them.
    source = textwrap.dedent(
        """\
        from dataclasses import field

        class FieldUsage:
            x = field(default_factory=list)
        """
    )
    module_tree, class_node = _get_class_node_and_module(source, "FieldUsage")

    codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 17.5μs -> 4.25μs (312% faster)

    expected = "from dataclasses import field"
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast

import pytest
from codeflash.context.code_context_extractor import extract_imports_for_class

class TestExtractImportsForClassBasic:
    """Basic test cases for extract_imports_for_class function."""

    def test_simple_class_with_single_import(self):
        """Test extracting imports for a class with a single base class import."""
        source = """from abc import ABC

class MyClass(ABC):
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 12.0μs -> 4.45μs (170% faster)

    def test_class_with_dataclass_decorator(self):
        """Test extracting imports for a class with @dataclass decorator."""
        source = """from dataclasses import dataclass

@dataclass
class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 11.7μs -> 4.15μs (183% faster)

    def test_class_with_type_annotation(self):
        """Test extracting imports for a class with type annotations."""
        source = """from typing import List

class MyClass:
    items: List[str]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 16.7μs -> 5.09μs (229% faster)

    def test_class_with_multiple_imports(self):
        """Test extracting multiple imports for a class."""
        source = """from abc import ABC
from typing import List

class MyClass(ABC):
    items: List[str]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[2]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 18.8μs -> 6.13μs (207% faster)

    def test_class_with_attribute_base_class(self):
        """Test extracting imports when base class is an attribute (e.g., abc.ABC)."""
        source = """import abc

class MyClass(abc.ABC):
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 13.3μs -> 4.50μs (196% faster)

    def test_class_with_no_imports_needed(self):
        """Test that empty string is returned when no imports are needed."""
        source = """class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[0]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 8.48μs -> 2.63μs (222% faster)

    def test_class_with_decorator_as_call(self):
        """Test extracting imports for decorator that is a function call."""
        source = """from dataclasses import dataclass, field

@dataclass(frozen=True)
class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 16.0μs -> 4.68μs (242% faster)

    def test_class_with_field_in_annotation(self):
        """Test extracting imports for field() calls in class body."""
        source = """from dataclasses import field

class MyClass:
    items: list = field(default_factory=list)
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 18.7μs -> 5.09μs (267% faster)

class TestExtractImportsForClassEdgeCases:
    """Edge case tests for extract_imports_for_class function."""

    def test_class_with_aliased_import(self):
        """Test extracting imports when class uses aliased imports."""
        source = """from typing import List as L

class MyClass:
    items: L[str]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 16.5μs -> 5.14μs (222% faster)

    def test_class_with_generic_types(self):
        """Test extracting imports for generic type annotations."""
        source = """from typing import Dict, List, Optional

class MyClass:
    mapping: Dict[str, List[int]]
    optional_val: Optional[str]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 29.1μs -> 7.61μs (282% faster)

    def test_class_with_union_type_annotation(self):
        """Test extracting imports for typing.Union annotations."""
        source = """from typing import Union

class MyClass:
    value: Union[str, int]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 19.4μs -> 5.63μs (244% faster)

    def test_class_with_nested_generics(self):
        """Test extracting imports for nested generic types."""
        source = """from typing import Dict, List

class MyClass:
    nested: Dict[str, List[Dict[str, int]]]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 27.8μs -> 6.77μs (310% faster)

    def test_class_with_duplicate_imports(self):
        """Test that duplicate imports are not added multiple times."""
        source = """from typing import List

class MyClass:
    items1: List[str]
    items2: List[int]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 22.8μs -> 5.86μs (289% faster)
        # Count occurrences of the import
        count = result.count("from typing import List")

    def test_class_with_unused_imports_in_module(self):
        """Test that only used imports are extracted."""
        source = """from typing import List
from datetime import datetime

class MyClass:
    items: List[str]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[2]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 17.0μs -> 5.47μs (210% faster)

    def test_class_with_multiple_base_classes(self):
        """Test extracting imports for class with multiple base classes."""
        source = """from abc import ABC, ABCMeta

class MyClass(ABC, ABCMeta):
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 13.1μs -> 4.51μs (190% faster)

    def test_class_with_multiple_decorators(self):
        """Test extracting imports for class with multiple decorators."""
        source = """from dataclasses import dataclass
from functools import wraps

@dataclass
@wraps
class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[2]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 13.8μs -> 5.50μs (150% faster)

    def test_class_with_attribute_decorator(self):
        """Test extracting imports when decorator uses attribute access."""
        source = """import functools

@functools.wraps
class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 12.6μs -> 3.00μs (320% faster)

    def test_empty_class_body(self):
        """Test class with empty body and only docstring."""
        source = """from abc import ABC

class MyClass(ABC):
    \"\"\"A class that inherits from ABC.\"\"\"
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 12.9μs -> 4.30μs (200% faster)

    def test_class_with_complex_decorator_arguments(self):
        """Test decorator with complex arguments."""
        source = """from dataclasses import dataclass

@dataclass(frozen=True, init=True)
class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 17.3μs -> 4.48μs (287% faster)

    def test_class_with_standard_library_types(self):
        """Test class using standard library types without explicit imports."""
        source = """class MyClass:
    items: list
    mapping: dict
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[0]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 16.1μs -> 3.82μs (323% faster)

    def test_class_with_tuple_annotation(self):
        """Test class with tuple type annotation."""
        source = """from typing import Tuple

class MyClass:
    values: Tuple[str, int, float]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 21.2μs -> 5.99μs (254% faster)

    def test_class_with_whitespace_in_imports(self):
        """Test that imports with various whitespace are correctly extracted."""
        source = """from typing import (
    List,
    Dict,
    Optional
)

class MyClass:
    items: List[str]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 16.7μs -> 5.23μs (219% faster)

class TestExtractImportsForClassLargeScale:
    """Large scale test cases for extract_imports_for_class function."""

    def test_class_with_many_annotations(self):
        """Test class with many type annotations."""
        # Create a class with 100 builtin-typed annotations; the typing import is present but unused
        annotations = "\n    ".join(
            [f"field_{i}: str" for i in range(100)]
        )
        source = f"""from typing import List

class MyClass:
    {annotations}
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 344μs -> 34.8μs (890% faster)

    def test_class_with_many_different_imports(self):
        """Test class that needs many different imports."""
        # Create 50 different imports and use them in base classes
        imports_str = "\n".join(
            [f"from module{i} import Class{i}" for i in range(50)]
        )
        bases_str = ", ".join([f"Class{i}" for i in range(50)])
        source = f"""{imports_str}

class MyClass({bases_str}):
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[50]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 93.3μs -> 33.9μs (176% faster)

    def test_class_with_deeply_nested_types(self):
        """Test class with deeply nested generic types."""
        source = """from typing import Dict, List, Optional

class MyClass:
    deep: Dict[str, List[Dict[str, List[Dict[str, Optional[int]]]]]]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 38.8μs -> 8.95μs (334% faster)

    def test_class_with_many_decorators(self):
        """Test class with many decorators."""
        decorators = "\n".join([f"@decorator_{i}" for i in range(30)])
        source = f"""from module import decorator_0, decorator_1

{decorators}
class MyClass:
    pass
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 49.3μs -> 10.5μs (369% faster)

    def test_class_with_mixed_imports_and_annotations(self):
        """Test class with both many imports and annotations."""
        imports_str = "\n".join(
            [f"from typing import Type{i}" for i in range(20)]
        )
        annotations = "\n    ".join(
            [f"field_{i}: Type{i}" for i in range(20)]
        )
        source = f"""{imports_str}

class MyClass:
    {annotations}
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[20]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 87.9μs -> 22.0μs (300% faster)
        # Check that all 20 imports are extracted
        lines = [line for line in result.split("\n") if line.strip()]

    def test_class_with_many_fields_same_type(self):
        """Test class with many fields of the same type."""
        annotations = "\n    ".join(
            [f"field_{i}: List[str]" for i in range(200)]
        )
        source = f"""from typing import List

class MyClass:
    {annotations}
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 1.20ms -> 119μs (900% faster)

    def test_module_with_many_classes(self):
        """Test extracting imports for a specific class in a module with many classes."""
        # Create a module with 50 classes
        classes_str = "\n".join(
            [f"class Class{i}:\n    pass\n" for i in range(50)]
        )
        source = f"""from abc import ABC

{classes_str}
class TargetClass(ABC):
    pass
"""
        module_tree = ast.parse(source)
        # Find the TargetClass node
        class_node = None
        for node in module_tree.body:
            if isinstance(node, ast.ClassDef) and node.name == "TargetClass":
                class_node = node
                break
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 21.9μs -> 9.19μs (139% faster)

    def test_class_with_union_of_many_types(self):
        """Test class with Union of many different types."""
        source = """from typing import Union

class MyClass:
    value: Union[str, int, float, bool, list, dict, tuple, set, frozenset]
"""
        module_tree = ast.parse(source)
        class_node = module_tree.body[1]
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 29.9μs -> 7.67μs (289% faster)

    def test_performance_with_large_module(self):
        """Test function performance with a large module (100 imports, 50 other classes)."""
        # Create a large module with many imports and classes
        imports = "\n".join([f"from module{i} import Class{i}" for i in range(100)])
        classes = "\n\n".join(
            [f"class Class{i}:\n    pass" for i in range(100, 150)]
        )
        source = f"""{imports}

{classes}

class TargetClass:
    field: Class50
"""
        module_tree = ast.parse(source)
        class_node = None
        for node in module_tree.body:
            if isinstance(node, ast.ClassDef) and node.name == "TargetClass":
                class_node = node
                break
        codeflash_output = extract_imports_for_class(module_tree, class_node, source); result = codeflash_output # 49.9μs -> 25.2μs (98.2% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-pr1335-2026-02-04T00.49.48 and push.


aseembits93 and others added 5 commits February 3, 2026 14:33
Add a `gpu` parameter to instrument tests with torch.cuda.Event timing
instead of time.perf_counter_ns() for measuring GPU kernel execution time.
Falls back to CPU timing when CUDA is not available/initialized.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
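
The timing behaviour this commit describes can be sketched roughly as follows. This is an illustration under assumptions, not the instrumentation code from PR #1335: `time_call` and its `gpu` keyword are hypothetical names, and torch is treated as an optional dependency.

```python
import time

try:
    import torch  # optional; the sketch falls back to CPU timing without it
except ImportError:
    torch = None


def time_call(fn, *args, gpu=False):
    """Hypothetical helper: return (result, elapsed nanoseconds)."""
    use_cuda = (
        gpu
        and torch is not None
        and torch.cuda.is_available()
        and torch.cuda.is_initialized()
    )
    if use_cuda:
        # CUDA events measure time on the device stream.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn(*args)
        end.record()
        torch.cuda.synchronize()  # wait until both events have completed
        return result, int(start.elapsed_time(end) * 1_000_000)  # ms -> ns
    # Fallback: wall-clock CPU timing.
    t0 = time.perf_counter_ns()
    result = fn(*args)
    return result, time.perf_counter_ns() - t0


result, elapsed_ns = time_call(sum, [1, 2, 3], gpu=True)
print(result, elapsed_ns)
```

On a machine without CUDA (or where CUDA has not been initialized), the call above takes the perf_counter_ns path even though gpu=True was requested.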
Fix unused variables, single-item membership tests, unnecessary lambdas,
and ternary expressions that can use the `or` operator.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
codeflash-ai bot added the labels "⚡️ codeflash" (Optimization PR opened by Codeflash AI) and "🎯 Quality: High" (Optimization Quality according to Codeflash) on Feb 4, 2026
KRRT7 commented Feb 19, 2026

Closing stale bot PR.

KRRT7 closed this Feb 19, 2026
KRRT7 deleted the codeflash/optimize-pr1335-2026-02-04T00.49.48 branch February 19, 2026 12:56