Skip to content

fix: instrument PyTorch nn.Module forward method calls via instance#1418

Open
aseembits93 wants to merge 8 commits intomainfrom
fix/pytorch-forward-method-instrumentation
Open

fix: instrument PyTorch nn.Module forward method calls via instance#1418
aseembits93 wants to merge 8 commits intomainfrom
fix/pytorch-forward-method-instrumentation

Conversation

@aseembits93
Copy link
Contributor

Summary

  • Fix instrumentation of PyTorch nn.Module forward method when called via instance (e.g., model(input_data))
  • Add special handling for the pattern: model = ClassName(...); model(input_data) where model(input_data) internally calls forward()
  • Resolves "Ignoring test case that passed but had no runtime" error when optimizing forward methods

Problem

When running codeflash --function AlexNet.forward, tests with this pattern weren't being instrumented:

model = AlexNet(num_classes=10)
result = model(input_data)  # calls __call__ which invokes forward()

The instrumentation was looking for direct calls to forward or AlexNet, but model(input_data) matched neither.

Solution

  1. Added collect_instance_variables() to track variables assigned from class instantiations
  2. Modified find_and_update_line_node() to wrap calls to instance variables when optimizing forward methods
  3. Added test case specifically for this PyTorch pattern

Test plan

  • Added test_pytorch_forward_method_instrumentation test case
  • All existing instrumentation tests pass (19/19)
  • Verified with actual codeflash command - runtime is now properly measured

🤖 Generated with Claude Code

When optimizing a `forward` method on a class (e.g., AlexNet.forward),
the test pattern `model = AlexNet(...); model(input_data)` wasn't being
instrumented because the call `model(input_data)` didn't match the
expected function name "forward".

This fix adds special handling for the PyTorch nn.Module pattern:
- Collect variable names assigned from class instantiations
- Also wrap calls to those instance variables when optimizing `forward`

Fixes the "Ignoring test case that passed but had no runtime" error
when running codeflash on PyTorch model forward methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
# Collect instance variables for forward method instrumentation (PyTorch pattern)
self.collect_instance_variables(node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.

Consider clearing the set at the start of each test function:

Suggested change
self.collect_instance_variables(node)
self.instance_variable_names.clear()
self.collect_instance_variables(node)

@claude
Copy link
Contributor

claude bot commented Feb 6, 2026

PR Review Summary

Prek Checks

All prek checks pass (ruff check, ruff format)

Mypy Type Checking

⚠️ Fixed 23 mypy errors (41 → 18 remaining). Changes committed as fix: resolve mypy type errors in changed files.

Fixes applied:

  • Added type annotation for found_qualified_name: str | None in ImportAnalyzer
  • Added type annotation for memory_cache in TestsCache
  • Added type annotation for new_body: list[ast.stmt] in AsyncCallInstrumenter._process_test_function
  • Added type annotation for stack: list[ast.AST] in _optimized_instrument_statement
  • Added type annotation for context_stack: list[str] in AsyncDecoratorAdder
  • Fixed visit_AsyncFunctionDef/visit_FunctionDef return types in AsyncCallInstrumenter
  • Fixed _is_target_decorator parameter type to cst.BaseExpression

Remaining 18 errors are pre-existing (present on main branch).

⚠️ Push failed due to remote having newer commits. Commit is local only — will need to be pushed manually after a rebase.

Code Review

Existing comment still valid:

  • Previous comment about clearing instance_variable_names before each test function is still not addressed. The InjectPerfOnly transformer is created once per file and processes all test functions — instance variable names from test_a will persist when processing test_b, potentially causing false-positive instrumentation.

No new critical issues found. The implementation correctly:

  • Tracks model = ClassName(...) assignments in visit_Assign (keeping it active even after finding a target function)
  • Detects model(input_data) as invoking forward() via visit_Call in ImportAnalyzer
  • Extracts instance-to-class mappings in InstanceMappingExtractor for Jedi fallback
  • Instruments instance variable calls as codeflash_wrap(model, ...) in InjectPerfOnly
  • Has a dedicated test (test_pytorch_forward_method_instrumentation)

Test Coverage

File PR Branch Main Branch Delta
instrument_existing_tests.py 90.0% 90.5% -0.5%
discover_unit_tests.py 72.7% 74.0% -1.3%
Overall 78.4% 78.4% ~0.0%

Coverage decreased slightly in both changed files. The new code in discover_unit_tests.py (PyTorch nn.Module.forward detection in process_test_files, ~55 lines) is only partially covered by tests — the test in test_instrument_tests.py covers the instrumentation path but not the test discovery path (which requires a full Jedi environment with actual PyTorch imports).

The test_alexnet.py sample test file is a fixture for manual/integration testing, not exercised by the unit test suite.


Last updated: 2026-02-17T02:30:00Z

The optimized code achieves a **768% speedup** (from 1.30ms to 150μs) by replacing the expensive `ast.walk()` traversal with a targeted manual traversal strategy.

**Key Optimization:**

The original code uses `ast.walk(func_node)`, which recursively visits *every* node in the entire AST tree - including all expression nodes, operators, literals, and other irrelevant node types. The line profiler shows this single loop consumed 87.3% of the execution time (9.2ms out of 10.5ms).

The optimized version implements a **work-list algorithm** that only traverses statement nodes (body, orelse, finalbody, handlers). This dramatically reduces the number of nodes examined:
- Original: 1,889 nodes visited per call
- Optimized: ~317 nodes visited per call (83% reduction)

**Why This Works:**

1. **Targeted traversal**: Assignment statements (`ast.Assign`) can only appear as statements, not as expressions buried deep in the tree. By only following statement-level structure (`body`, `orelse`, etc.), we skip visiting thousands of irrelevant expression nodes.

2. **Cache-friendly**: Local variables `class_name` and `instance_vars` eliminate repeated `self.` attribute lookups, reducing pointer indirection.

3. **Early filtering**: The manual stack-based approach allows us to skip entire branches of the AST that can't contain assignments.

**Performance Impact by Test Case:**

- Simple cases (single assignment): ~500-600% faster
- Complex nested cases: ~429% faster  
- Large-scale scenario (300 assignments): **807% faster** - showing the optimization scales particularly well with code complexity

The optimization preserves all functionality (same nodes discovered, same instance variables collected) while dramatically reducing the algorithmic complexity from O(all_nodes) to O(statement_nodes).
@codeflash-ai
Copy link
Contributor

codeflash-ai bot commented Feb 6, 2026

⚡️ Codeflash found optimizations for this PR

📄 769% (7.69x) speedup for InjectPerfOnly.collect_instance_variables in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 1.30 milliseconds 150 microseconds (best of 15 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch fix/pytorch-forward-method-instrumentation).

Static Badge

…2026-02-06T22.39.42

⚡️ Speed up method `InjectPerfOnly.collect_instance_variables` by 769% in PR #1418 (`fix/pytorch-forward-method-instrumentation`)
@codeflash-ai
Copy link
Contributor

codeflash-ai bot commented Feb 11, 2026

@KRRT7
Copy link
Collaborator

KRRT7 commented Feb 11, 2026

@claude fix the mypy type issues and push

1 similar comment
@KRRT7
Copy link
Collaborator

KRRT7 commented Feb 11, 2026

@claude fix the mypy type issues and push

Comment on lines +526 to +533
for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
if alias.asname:
self.alias_mapping[imported_name] = alias.name
self.generic_visit(node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 363% (3.63x) speedup for InstanceMappingExtractor.visit_ImportFrom in codeflash/discovery/discover_unit_tests.py

⏱️ Runtime : 1.52 milliseconds 329 microseconds (best of 129 runs)

📝 Explanation and details

The optimized code achieves a 362% speedup (from 1.52ms to 329μs) by eliminating an expensive, unnecessary call to generic_visit() that was consuming 80% of the original runtime.

Key Optimization

The original code unconditionally called self.generic_visit(node) at the end of visit_ImportFrom(), which triggers a full AST traversal of all child nodes. However, ImportFrom nodes only contain alias children that were already processed in the explicit for alias in node.names loop above.

The optimized version:

  1. Checks if a subclass has overridden generic_visit() - If so, calls it to preserve custom traversal behavior
  2. Otherwise, skips the heavy traversal entirely and only manually visits alias children if a visit_alias() handler exists

This change is safe because:

  • The InstanceMappingExtractor class doesn't override generic_visit() or define visit_alias()
  • All alias processing is already done explicitly in the loop
  • The redundant traversal was performing no additional work beyond what the loop already accomplished

Performance Impact

The line profiler shows the generic_visit() call dropped from 15.6ms (80% of runtime) to essentially zero. Test results demonstrate consistent speedups across all scenarios:

  • Simple imports: 165-225% faster
  • Large-scale imports (1000+ aliases): 372-386% faster
  • The optimization is most impactful for files with many import statements

The few test cases showing minor slowdowns (2-8%) involve trivial early-return paths (e.g., module=None) where the new conditional checks add overhead that exceeds the already-minimal original cost. This is a reasonable trade-off given the dramatic improvements in all realistic use cases.

Context

Since InstanceMappingExtractor is used for AST analysis (extracting import mappings for PyTorch module detection), it likely processes many Python files with numerous imports. This optimization significantly reduces the cost of analyzing import-heavy codebases.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 55 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast

import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor

def test_single_import_without_alias_adds_module_name_to_imported_modules():
    # Create a visitor instance to test instance state changes.
    extractor = InstanceMappingExtractor()
    # Create an ImportFrom node: from mypkg import MyClass
    node = ast.ImportFrom(module="mypkg", names=[ast.alias(name="MyClass", asname=None)], level=0)
    # Call the specific visitor method directly as required.
    extractor.visit_ImportFrom(node) # 4.57μs -> 1.48μs (208% faster)

def test_single_import_with_alias_creates_alias_mapping_and_imported_modules_entry():
    # Create a fresh visitor instance.
    extractor = InstanceMappingExtractor()
    # Create an ImportFrom node with an alias: from pkg import RealName as AliasName
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="RealName", asname="AliasName")], level=0)
    # Invoke the method under test.
    extractor.visit_ImportFrom(node) # 4.64μs -> 1.75μs (165% faster)

def test_star_import_is_ignored_and_does_not_modify_mappings():
    extractor = InstanceMappingExtractor()
    # from pkg import *
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="*", asname=None)], level=0)
    extractor.visit_ImportFrom(node) # 4.20μs -> 1.29μs (225% faster)

def test_none_module_is_ignored_no_changes_made():
    extractor = InstanceMappingExtractor()
    # An ImportFrom with module set to None should be ignored.
    node = ast.ImportFrom(module=None, names=[ast.alias(name="Name", asname=None)], level=0)
    extractor.visit_ImportFrom(node) # 461ns -> 471ns (2.12% slower)

def test_empty_names_list_does_not_raise_and_makes_no_changes():
    extractor = InstanceMappingExtractor()
    # from pkg import  (no names) -> names list empty
    node = ast.ImportFrom(module="pkg", names=[], level=0)
    # Should simply return without error and without modifications.
    extractor.visit_ImportFrom(node) # 2.25μs -> 1.24μs (81.6% faster)

def test_empty_asname_string_treated_as_no_alias():
    extractor = InstanceMappingExtractor()
    # asname is empty string: from pkg import Name as ""
    # Empty string is falsy; the implementation checks `if alias.asname:`
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="Name", asname="")], level=0)
    extractor.visit_ImportFrom(node) # 4.54μs -> 1.63μs (178% faster)

def test_special_characters_in_names_and_aliases_handled_correctly():
    extractor = InstanceMappingExtractor()
    # Use names with underscores and digits and unusual but valid identifier-like strings.
    aliases = [
        ast.alias(name="Cls_1", asname=None),
        ast.alias(name="RealName2", asname="alias_2"),
        ast.alias(name="ÎnvalidUnicode", asname="alias_unicode"),  # unicode inside Python identifier context
    ]
    node = ast.ImportFrom(module="some_pkg", names=aliases, level=0)
    extractor.visit_ImportFrom(node) # 7.10μs -> 2.33μs (204% faster)

def test_duplicate_aliases_do_not_raise_and_result_in_set_behavior_for_imported_modules():
    extractor = InstanceMappingExtractor()
    # Two identical asnames and real names repeated; set semantics means only one stored.
    aliases = [
        ast.alias(name="A", asname="X"),
        ast.alias(name="A", asname="X"),
        ast.alias(name="B", asname="Y"),
    ]
    node = ast.ImportFrom(module="dup_pkg", names=aliases, level=0)
    extractor.visit_ImportFrom(node) # 7.22μs -> 2.40μs (202% faster)

def test_large_scale_many_aliases_performance_and_correctness():
    extractor = InstanceMappingExtractor()
    # Construct 1000 aliases where every even one has an asname, odds do not.
    num = 1000
    names = []
    for i in range(num):
        real = f"Real{i}"
        # Give half of them an alias, half not.
        if i % 2 == 0:
            # alias name will be Alias{i}
            names.append(ast.alias(name=real, asname=f"Alias{i}"))
        else:
            names.append(ast.alias(name=real, asname=None))
    node = ast.ImportFrom(module="big_pkg", names=names, level=0)
    # Ensure this runs quickly and does not raise.
    extractor.visit_ImportFrom(node) # 986μs -> 208μs (372% faster)
    # Validate sizes: imported_modules should contain the aliased names (num/2) plus the non-aliased original names (num/2).
    expected_imported = set()
    for i in range(num):
        if i % 2 == 0:
            expected_imported.add(f"Alias{i}")
        else:
            expected_imported.add(f"Real{i}")
    # alias_mapping should contain only the even indices mapping alias->real.
    expected_alias_map = {f"Alias{i}": f"Real{i}" for i in range(0, num, 2)}

def test_mixed_star_and_regular_imports_large_scale():
    extractor = InstanceMappingExtractor()
    # Build a large list mixing star imports, aliased imports and non-aliased imports.
    names = []
    num = 500
    for i in range(num):
        if i % 10 == 0:
            # include a star import occasionally
            names.append(ast.alias(name="*", asname=None))
        elif i % 3 == 0:
            names.append(ast.alias(name=f"Real{i}", asname=f"Alias{i}"))
        else:
            names.append(ast.alias(name=f"Real{i}", asname=None))
    node = ast.ImportFrom(module="mix_pkg", names=names, level=0)
    extractor.visit_ImportFrom(node) # 480μs -> 98.9μs (386% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast

# imports
import pytest
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor

def test_visit_import_from_single_import():
    """Test visiting a simple import statement with a single name."""
    code = "from module import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_with_alias():
    """Test visiting an import statement with an alias."""
    code = "from module import name as alias_name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_multiple_names():
    """Test visiting an import statement with multiple names."""
    code = "from module import name1, name2, name3"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_mixed_aliases():
    """Test visiting an import statement with both aliased and non-aliased names."""
    code = "from module import name1, name2 as alias2, name3"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_star_import():
    """Test that star imports are skipped and not added to imported_modules."""
    code = "from module import *"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_no_module():
    """Test that ImportFrom with no module attribute is handled gracefully."""
    # Create an ImportFrom node with module=None (relative import)
    node = ast.ImportFrom(module=None, names=[ast.alias(name="name", asname=None)], level=1)
    extractor = InstanceMappingExtractor()
    
    # Should return None without error
    codeflash_output = extractor.visit_ImportFrom(node); result = codeflash_output # 571ns -> 602ns (5.15% slower)

def test_visit_import_from_torch_nn_module():
    """Test a realistic PyTorch import scenario."""
    code = "from torch import nn"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_multiple_imports_same_module():
    """Test visiting multiple ImportFrom statements from different modules."""
    code = """from module1 import name1
from module2 import name2 as alias2
from module3 import name3"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_empty_module_name():
    """Test ImportFrom with an empty string as module name."""
    # Create an ImportFrom node with empty module
    node = ast.ImportFrom(module="", names=[ast.alias(name="name", asname=None)], level=0)
    extractor = InstanceMappingExtractor()
    
    # Process the node
    extractor.visit_ImportFrom(node) # 561ns -> 611ns (8.18% slower)

def test_visit_import_from_no_names():
    """Test ImportFrom with empty names list."""
    node = ast.ImportFrom(module="module", names=[], level=0)
    extractor = InstanceMappingExtractor()
    
    # Process the node
    extractor.visit_ImportFrom(node) # 2.52μs -> 1.23μs (104% faster)

def test_visit_import_from_special_characters_in_name():
    """Test import names with underscores and numbers."""
    code = "from module import _private, __dunder__, name123"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_long_module_path():
    """Test ImportFrom with a deeply nested module path."""
    code = "from package.subpackage.module import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_relative_import_with_level():
    """Test relative imports with various levels."""
    code = "from . import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_star_with_other_names():
    """Test that when star is mixed with other names, only star is skipped."""
    # Create a node with star and another name
    node = ast.ImportFrom(
        module="module",
        names=[ast.alias(name="*", asname=None), ast.alias(name="name", asname=None)],
        level=0
    )
    extractor = InstanceMappingExtractor()
    
    # Process the node
    extractor.visit_ImportFrom(node) # 5.86μs -> 1.92μs (205% faster)

def test_visit_import_from_alias_overwrites_same_name():
    """Test that importing the same name twice with different aliases uses the latest."""
    code = """from module import name as alias1
from module import name as alias2"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_single_char_names():
    """Test importing single character names."""
    code = "from module import a, b, c"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_very_long_names():
    """Test importing very long names."""
    long_name = "a" * 100
    code = f"from module import {long_name}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_alias_same_as_original():
    """Test when alias is the same as the original name."""
    code = "from module import name as name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_preserves_state():
    """Test that the extractor preserves state across multiple visits."""
    code1 = "from module1 import name1"
    code2 = "from module2 import name2"
    
    tree1 = ast.parse(code1)
    tree2 = ast.parse(code2)
    
    extractor = InstanceMappingExtractor()
    extractor.visit(tree1)
    extractor.visit(tree2)

def test_visit_import_from_many_imports_single_statement():
    """Test importing many names from a single module."""
    # Generate a large import statement
    names = [f"name{i}" for i in range(500)]
    code = f"from module import {', '.join(names)}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass

def test_visit_import_from_many_imports_multiple_statements():
    """Test processing many separate ImportFrom statements."""
    code_lines = [f"from module{i} import name{i}" for i in range(500)]
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass

def test_visit_import_from_many_aliases():
    """Test importing many names with aliases."""
    names = [f"name{i} as alias{i}" for i in range(500)]
    code = f"from module import {', '.join(names)}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass

def test_visit_import_from_complex_mixed_imports():
    """Test complex scenario with many mixed imports."""
    code_lines = []
    for i in range(250):
        if i % 3 == 0:
            # Non-aliased import
            code_lines.append(f"from module{i} import name{i}")
        elif i % 3 == 1:
            # Aliased import
            code_lines.append(f"from module{i} import name{i} as alias{i}")
        else:
            # Multiple imports
            code_lines.append(f"from module{i} import name{i}a, name{i}b, name{i}c")
    
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_deeply_nested_code():
    """Test ImportFrom statements within nested code structures."""
    code = """
def func():
    if True:
        from module1 import name1
        for i in range(10):
            from module2 import name2
            try:
                from module3 import name3
            except:
                pass
    from module4 import name4
"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_instance_mapping_remains_empty():
    """Test that instance_mapping remains empty after visiting ImportFrom nodes."""
    code = "from module import name1, name2 as alias2"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_set_behavior_duplicates():
    """Test that imported_modules set prevents duplicates."""
    # Manually create scenario where same name could be added twice
    node1 = ast.ImportFrom(module="module1", names=[ast.alias(name="name", asname=None)], level=0)
    node2 = ast.ImportFrom(module="module2", names=[ast.alias(name="name", asname=None)], level=0)
    
    extractor = InstanceMappingExtractor()
    extractor.visit_ImportFrom(node1) # 4.74μs -> 1.77μs (167% faster)
    initial_size = len(extractor.imported_modules)
    extractor.visit_ImportFrom(node2) # 2.77μs -> 892ns (211% faster)
    final_size = len(extractor.imported_modules)

def test_visit_import_from_generic_visit_called():
    """Test that generic_visit is properly called for child traversal."""
    # Create a nested AST structure that would require generic_visit
    code = """
from module import (
    name1,
    name2,
    name3
)
"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)

def test_visit_import_from_return_value_is_none():
    """Test that visit_ImportFrom returns None (standard NodeVisitor behavior)."""
    code = "from module import name"
    tree = ast.parse(code)
    
    # Get the ImportFrom node
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            extractor = InstanceMappingExtractor()
            codeflash_output = extractor.visit_ImportFrom(node); result = codeflash_output
            break
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1418-2026-02-17T13.07.22

Click to see suggested changes
Suggested change
for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
if alias.asname:
self.alias_mapping[imported_name] = alias.name
self.generic_visit(node)
# Preserve original behavior: if a subclass has overridden generic_visit,
# call it to allow that custom traversal to run. Otherwise, avoid the
# heavy generic traversal and only visit alias children if a specific
# visit_alias handler exists.
if getattr(type(self), "generic_visit", ast.NodeVisitor.generic_visit) is not ast.NodeVisitor.generic_visit:
self.generic_visit(node)
return
has_visit_alias = hasattr(self, "visit_alias")
for alias in node.names:
if alias.name == "*":
continue
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
if alias.asname:
self.alias_mapping[imported_name] = alias.name
if has_visit_alias:
# If a visitor for alias nodes exists, invoke it to match
# the behavior of the default generic_visit which would have
# dispatched to visit_alias.
self.visit(alias)

Static Badge

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments