fix: instrument PyTorch nn.Module forward method calls via instance #1418
aseembits93 wants to merge 8 commits into main
When optimizing a `forward` method on a class (e.g., `AlexNet.forward`), the test pattern `model = AlexNet(...); model(input_data)` wasn't being instrumented because the call `model(input_data)` didn't match the expected function name "forward".

This fix adds special handling for the PyTorch nn.Module pattern:
- Collect variable names assigned from class instantiations
- Also wrap calls to those instance variables when optimizing `forward`

Fixes the "Ignoring test case that passed but had no runtime" error when running codeflash on PyTorch model forward methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
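The mismatch described above can be reproduced without PyTorch. The sketch below is illustrative only: `TinyModule` is a hypothetical stand-in that mimics just the `__call__` to `forward` dispatch of `nn.Module`, and `called_names` is a hypothetical helper, not part of codeflash. It shows that the AST of the test code contains calls named `AlexNet` and `model`, but never `forward`.

```python
import ast


class TinyModule:
    """Hypothetical stand-in for torch.nn.Module: calling the instance runs forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class AlexNet(TinyModule):
    def forward(self, x):
        return x * 2


TEST_SOURCE = "model = AlexNet()\nresult = model(21)\n"


def called_names(source: str) -> list[str]:
    """Return the name that each call expression in `source` targets."""
    names = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                names.append(func.id)  # e.g. AlexNet(...) or model(...)
            elif isinstance(func, ast.Attribute):
                names.append(func.attr)  # e.g. model.forward(...)
    return names


names = called_names(TEST_SOURCE)
```

At runtime `AlexNet()(21)` does execute `forward`, yet a matcher that only looks for the literal name `forward` in call sites finds nothing to wrap, which is the gap this PR closes.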
| def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef: | ||
| if node.name.startswith("test_"): | ||
| # Collect instance variables for forward method instrumentation (PyTorch pattern) | ||
| self.collect_instance_variables(node) |
Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.
Consider clearing the set at the start of each test function:
| self.collect_instance_variables(node) | |
| self.instance_variable_names.clear() | |
| self.collect_instance_variables(node) |
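To make the concern concrete, here is a minimal sketch of the leak and of the suggested fix. `InstanceCollector` is a hypothetical, simplified collector written for this illustration; it is not the real `InjectPerfOnly` class, and the `clear_per_test` flag and `seen_per_test` snapshot exist only to compare the two behaviours.

```python
import ast


class InstanceCollector(ast.NodeVisitor):
    """Hypothetical, simplified collector; not the real InjectPerfOnly class."""

    def __init__(self, class_name: str, clear_per_test: bool) -> None:
        self.class_name = class_name
        self.clear_per_test = clear_per_test
        self.instance_variable_names: set[str] = set()
        # Snapshot of the set as it stands while each test function is processed.
        self.seen_per_test: dict[str, set[str]] = {}

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if not node.name.startswith("test_"):
            return
        if self.clear_per_test:
            # The suggested fix: start every test function with an empty set.
            self.instance_variable_names.clear()
        for stmt in node.body:
            if (
                isinstance(stmt, ast.Assign)
                and isinstance(stmt.value, ast.Call)
                and isinstance(stmt.value.func, ast.Name)
                and stmt.value.func.id == self.class_name
            ):
                for target in stmt.targets:
                    if isinstance(target, ast.Name):
                        self.instance_variable_names.add(target.id)
        self.seen_per_test[node.name] = set(self.instance_variable_names)


SOURCE = '''
def test_a():
    model = AlexNet()
    model(1)

def test_b():
    model = print  # unrelated callable that happens to reuse the name
    model(2)
'''

tree = ast.parse(SOURCE)
leaky = InstanceCollector("AlexNet", clear_per_test=False)
leaky.visit(tree)
scoped = InstanceCollector("AlexNet", clear_per_test=True)
scoped.visit(tree)
```

Without clearing, `model` is still in the set while `test_b` is processed, so `model(2)` would be treated as a `forward` call; with clearing, the set is empty for `test_b`.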
PR Review Summary
Prek Checks
✅ All prek checks pass (ruff check, ruff format)
Mypy Type Checking
Fixes applied:
Remaining 18 errors are pre-existing (present on
Code Review
Existing comment still valid:
No new critical issues found. The implementation correctly:
Test Coverage
Coverage decreased slightly in both changed files. The new code in The
Last updated: 2026-02-17T02:30:00Z
The optimized code achieves a **768% speedup** (from 1.30ms to 150μs) by replacing the expensive `ast.walk()` traversal with a targeted manual traversal strategy.

**Key Optimization:**
The original code uses `ast.walk(func_node)`, which recursively visits *every* node in the entire AST tree - including all expression nodes, operators, literals, and other irrelevant node types. The line profiler shows this single loop consumed 87.3% of the execution time (9.2ms out of 10.5ms).

The optimized version implements a **work-list algorithm** that only traverses statement nodes (body, orelse, finalbody, handlers). This dramatically reduces the number of nodes examined:
- Original: 1,889 nodes visited per call
- Optimized: ~317 nodes visited per call (83% reduction)

**Why This Works:**
1. **Targeted traversal**: Assignment statements (`ast.Assign`) can only appear as statements, not as expressions buried deep in the tree. By only following statement-level structure (`body`, `orelse`, etc.), we skip visiting thousands of irrelevant expression nodes.
2. **Cache-friendly**: Local variables `class_name` and `instance_vars` eliminate repeated `self.` attribute lookups, reducing pointer indirection.
3. **Early filtering**: The manual stack-based approach allows us to skip entire branches of the AST that can't contain assignments.

**Performance Impact by Test Case:**
- Simple cases (single assignment): ~500-600% faster
- Complex nested cases: ~429% faster
- Large-scale scenario (300 assignments): **807% faster** - showing the optimization scales particularly well with code complexity

The optimization preserves all functionality (same nodes discovered, same instance variables collected) while dramatically reducing the algorithmic complexity from O(all_nodes) to O(statement_nodes).
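The work-list idea can be sketched as follows. This is an illustration under assumptions, not the code merged in the PR: the function names and the `_record_if_instantiation` helper are hypothetical, and only the four statement-list fields named above are followed (so, for example, `match` statement cases would not be reached).

```python
import ast


def _record_if_instantiation(node: ast.AST, class_name: str, found: set[str]) -> None:
    """Record `x` when `node` is an assignment of the form `x = ClassName(...)`."""
    if (
        isinstance(node, ast.Assign)
        and isinstance(node.value, ast.Call)
        and isinstance(node.value.func, ast.Name)
        and node.value.func.id == class_name
    ):
        for target in node.targets:
            if isinstance(target, ast.Name):
                found.add(target.id)


def collect_with_walk(func_node: ast.AST, class_name: str) -> set[str]:
    """Baseline: inspect every node, expressions included."""
    found: set[str] = set()
    for node in ast.walk(func_node):
        _record_if_instantiation(node, class_name, found)
    return found


def collect_with_worklist(func_node: ast.AST, class_name: str) -> set[str]:
    """Work-list variant: only follow fields that hold lists of statements."""
    found: set[str] = set()
    stack = [func_node]
    while stack:
        node = stack.pop()
        _record_if_instantiation(node, class_name, found)
        for field in ("body", "orelse", "finalbody", "handlers"):
            children = getattr(node, field, None)
            if isinstance(children, list):
                stack.extend(children)
    return found


SOURCE = '''
def test_forward():
    model = AlexNet()
    if True:
        other = AlexNet()
    try:
        third = AlexNet()
    except ValueError:
        fourth = AlexNet()
    finally:
        unrelated = [i * 2 for i in range(10)]
'''

func_node = ast.parse(SOURCE).body[0]
walk_result = collect_with_walk(func_node, "AlexNet")
worklist_result = collect_with_worklist(func_node, "AlexNet")
```

Both functions find the same four variable names here; the work-list version simply never descends into the list comprehension or any other expression subtree.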
⚡️ Codeflash found optimizations for this PR
📄 769% (7.69x) speedup for
…2026-02-06T22.39.42 ⚡️ Speed up method `InjectPerfOnly.collect_instance_variables` by 769% in PR #1418 (`fix/pytorch-forward-method-instrumentation`)
This PR is now faster! 🚀 @KRRT7 accepted my optimizations from:
@claude fix the mypy type issues and push
| for alias in node.names: | ||
| if alias.name == "*": | ||
| continue | ||
| imported_name = alias.asname if alias.asname else alias.name | ||
| self.imported_modules.add(imported_name) | ||
| if alias.asname: | ||
| self.alias_mapping[imported_name] = alias.name | ||
| self.generic_visit(node) |
⚡️Codeflash found 363% (3.63x) speedup for InstanceMappingExtractor.visit_ImportFrom in codeflash/discovery/discover_unit_tests.py
⏱️ Runtime : 1.52 milliseconds → 329 microseconds (best of 129 runs)
📝 Explanation and details
The optimized code achieves a 363% speedup (from 1.52ms to 329μs) by eliminating an expensive, unnecessary call to `generic_visit()` that was consuming 80% of the original runtime.
Key Optimization
The original code unconditionally called `self.generic_visit(node)` at the end of `visit_ImportFrom()`, which triggers a full AST traversal of all child nodes. However, `ImportFrom` nodes only contain `alias` children that were already processed in the explicit `for alias in node.names` loop above.
The optimized version:
- Checks if a subclass has overridden `generic_visit()`
- If so, calls it to preserve custom traversal behavior
- Otherwise, skips the heavy traversal entirely and only manually visits alias children if a `visit_alias()` handler exists
This change is safe because:
- The `InstanceMappingExtractor` class doesn't override `generic_visit()` or define `visit_alias()`
- All alias processing is already done explicitly in the loop
- The redundant traversal was performing no additional work beyond what the loop already accomplished
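A minimal sketch of the idea follows. `ImportCollector` is a hypothetical, simplified stand-in for `InstanceMappingExtractor`; the early return on a missing module is an assumption inferred from the generated tests, and the optional `visit_alias()` dispatch is omitted for brevity. In this variant the bookkeeping loop runs before the delegation check, so a subclass that overrides `generic_visit()` would still get `imported_modules` and `alias_mapping` populated.

```python
import ast


class ImportCollector(ast.NodeVisitor):
    """Hypothetical, simplified stand-in for InstanceMappingExtractor."""

    def __init__(self) -> None:
        self.imported_modules: set[str] = set()
        self.alias_mapping: dict[str, str] = {}

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if not node.module:  # assumed early return for `from . import x`
            return
        for alias in node.names:
            if alias.name == "*":
                continue
            imported_name = alias.asname if alias.asname else alias.name
            self.imported_modules.add(imported_name)
            if alias.asname:
                self.alias_mapping[imported_name] = alias.name
        # Delegate to generic_visit only when a subclass customised it;
        # the default implementation would just re-walk the alias children.
        if type(self).generic_visit is not ast.NodeVisitor.generic_visit:
            self.generic_visit(node)


collector = ImportCollector()
collector.visit(ast.parse("from torch import nn as neural, optim\nfrom pkg import *\n"))
```

After the visit, `collector.imported_modules` holds `neural` and `optim`, and `collector.alias_mapping` maps `neural` back to `nn`; the star import contributes nothing.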
Performance Impact
The line profiler shows the `generic_visit()` call dropped from 15.6ms (80% of runtime) to essentially zero. Test results demonstrate consistent speedups across all scenarios:
- Simple imports: 165-225% faster
- Large-scale imports (1000+ aliases): 372-386% faster
- The optimization is most impactful for files with many import statements
The few test cases showing minor slowdowns (2-8%) involve trivial early-return paths (e.g., `module=None`) where the new conditional checks add overhead that exceeds the already-minimal original cost. This is a reasonable trade-off given the dramatic improvements in all realistic use cases.
Context
Since `InstanceMappingExtractor` is used for AST analysis (extracting import mappings for PyTorch module detection), it likely processes many Python files with numerous imports. This optimization significantly reduces the cost of analyzing import-heavy codebases.
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 55 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests
import ast

import pytest  # used for our unit tests
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor


def test_single_import_without_alias_adds_module_name_to_imported_modules():
    # Create a visitor instance to test instance state changes.
    extractor = InstanceMappingExtractor()
    # Create an ImportFrom node: from mypkg import MyClass
    node = ast.ImportFrom(module="mypkg", names=[ast.alias(name="MyClass", asname=None)], level=0)
    # Call the specific visitor method directly as required.
    extractor.visit_ImportFrom(node)  # 4.57μs -> 1.48μs (208% faster)


def test_single_import_with_alias_creates_alias_mapping_and_imported_modules_entry():
    # Create a fresh visitor instance.
    extractor = InstanceMappingExtractor()
    # Create an ImportFrom node with an alias: from pkg import RealName as AliasName
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="RealName", asname="AliasName")], level=0)
    # Invoke the method under test.
    extractor.visit_ImportFrom(node)  # 4.64μs -> 1.75μs (165% faster)


def test_star_import_is_ignored_and_does_not_modify_mappings():
    extractor = InstanceMappingExtractor()
    # from pkg import *
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="*", asname=None)], level=0)
    extractor.visit_ImportFrom(node)  # 4.20μs -> 1.29μs (225% faster)


def test_none_module_is_ignored_no_changes_made():
    extractor = InstanceMappingExtractor()
    # An ImportFrom with module set to None should be ignored.
    node = ast.ImportFrom(module=None, names=[ast.alias(name="Name", asname=None)], level=0)
    extractor.visit_ImportFrom(node)  # 461ns -> 471ns (2.12% slower)


def test_empty_names_list_does_not_raise_and_makes_no_changes():
    extractor = InstanceMappingExtractor()
    # from pkg import (no names) -> names list empty
    node = ast.ImportFrom(module="pkg", names=[], level=0)
    # Should simply return without error and without modifications.
    extractor.visit_ImportFrom(node)  # 2.25μs -> 1.24μs (81.6% faster)


def test_empty_asname_string_treated_as_no_alias():
    extractor = InstanceMappingExtractor()
    # asname is empty string: from pkg import Name as ""
    # Empty string is falsy; the implementation checks `if alias.asname:`
    node = ast.ImportFrom(module="pkg", names=[ast.alias(name="Name", asname="")], level=0)
    extractor.visit_ImportFrom(node)  # 4.54μs -> 1.63μs (178% faster)


def test_special_characters_in_names_and_aliases_handled_correctly():
    extractor = InstanceMappingExtractor()
    # Use names with underscores and digits and unusual but valid identifier-like strings.
    aliases = [
        ast.alias(name="Cls_1", asname=None),
        ast.alias(name="RealName2", asname="alias_2"),
        ast.alias(name="ÎnvalidUnicode", asname="alias_unicode"),  # unicode inside Python identifier context
    ]
    node = ast.ImportFrom(module="some_pkg", names=aliases, level=0)
    extractor.visit_ImportFrom(node)  # 7.10μs -> 2.33μs (204% faster)


def test_duplicate_aliases_do_not_raise_and_result_in_set_behavior_for_imported_modules():
    extractor = InstanceMappingExtractor()
    # Two identical asnames and real names repeated; set semantics means only one stored.
    aliases = [
        ast.alias(name="A", asname="X"),
        ast.alias(name="A", asname="X"),
        ast.alias(name="B", asname="Y"),
    ]
    node = ast.ImportFrom(module="dup_pkg", names=aliases, level=0)
    extractor.visit_ImportFrom(node)  # 7.22μs -> 2.40μs (202% faster)


def test_large_scale_many_aliases_performance_and_correctness():
    extractor = InstanceMappingExtractor()
    # Construct 1000 aliases where every even one has an asname, odds do not.
    num = 1000
    names = []
    for i in range(num):
        real = f"Real{i}"
        # Give half of them an alias, half not.
        if i % 2 == 0:
            # alias name will be Alias{i}
            names.append(ast.alias(name=real, asname=f"Alias{i}"))
        else:
            names.append(ast.alias(name=real, asname=None))
    node = ast.ImportFrom(module="big_pkg", names=names, level=0)
    # Ensure this runs quickly and does not raise.
    extractor.visit_ImportFrom(node)  # 986μs -> 208μs (372% faster)
    # Validate sizes: imported_modules should contain the aliased names (num/2) plus the non-aliased original names (num/2).
    expected_imported = set()
    for i in range(num):
        if i % 2 == 0:
            expected_imported.add(f"Alias{i}")
        else:
            expected_imported.add(f"Real{i}")
    # alias_mapping should contain only the even indices mapping alias->real.
    expected_alias_map = {f"Alias{i}": f"Real{i}" for i in range(0, num, 2)}


def test_mixed_star_and_regular_imports_large_scale():
    extractor = InstanceMappingExtractor()
    # Build a large list mixing star imports, aliased imports and non-aliased imports.
    names = []
    num = 500
    for i in range(num):
        if i % 10 == 0:
            # include a star import occasionally
            names.append(ast.alias(name="*", asname=None))
        elif i % 3 == 0:
            names.append(ast.alias(name=f"Real{i}", asname=f"Alias{i}"))
        else:
            names.append(ast.alias(name=f"Real{i}", asname=None))
    node = ast.ImportFrom(module="mix_pkg", names=names, level=0)
    extractor.visit_ImportFrom(node)  # 480μs -> 98.9μs (386% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

import ast

# imports
import pytest
from codeflash.discovery.discover_unit_tests import InstanceMappingExtractor


def test_visit_import_from_single_import():
    """Test visiting a simple import statement with a single name."""
    code = "from module import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_with_alias():
    """Test visiting an import statement with an alias."""
    code = "from module import name as alias_name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_multiple_names():
    """Test visiting an import statement with multiple names."""
    code = "from module import name1, name2, name3"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_mixed_aliases():
    """Test visiting an import statement with both aliased and non-aliased names."""
    code = "from module import name1, name2 as alias2, name3"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_star_import():
    """Test that star imports are skipped and not added to imported_modules."""
    code = "from module import *"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_no_module():
    """Test that ImportFrom with no module attribute is handled gracefully."""
    # Create an ImportFrom node with module=None (relative import)
    node = ast.ImportFrom(module=None, names=[ast.alias(name="name", asname=None)], level=1)
    extractor = InstanceMappingExtractor()
    # Should return None without error
    codeflash_output = extractor.visit_ImportFrom(node)  # 571ns -> 602ns (5.15% slower)
    result = codeflash_output


def test_visit_import_from_torch_nn_module():
    """Test a realistic PyTorch import scenario."""
    code = "from torch import nn"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_multiple_imports_same_module():
    """Test visiting multiple ImportFrom statements from different modules."""
    code = """from module1 import name1
from module2 import name2 as alias2
from module3 import name3"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_empty_module_name():
    """Test ImportFrom with an empty string as module name."""
    # Create an ImportFrom node with empty module
    node = ast.ImportFrom(module="", names=[ast.alias(name="name", asname=None)], level=0)
    extractor = InstanceMappingExtractor()
    # Process the node
    extractor.visit_ImportFrom(node)  # 561ns -> 611ns (8.18% slower)


def test_visit_import_from_no_names():
    """Test ImportFrom with empty names list."""
    node = ast.ImportFrom(module="module", names=[], level=0)
    extractor = InstanceMappingExtractor()
    # Process the node
    extractor.visit_ImportFrom(node)  # 2.52μs -> 1.23μs (104% faster)


def test_visit_import_from_special_characters_in_name():
    """Test import names with underscores and numbers."""
    code = "from module import _private, __dunder__, name123"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_long_module_path():
    """Test ImportFrom with a deeply nested module path."""
    code = "from package.subpackage.module import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_relative_import_with_level():
    """Test relative imports with various levels."""
    code = "from . import name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_star_with_other_names():
    """Test that when star is mixed with other names, only star is skipped."""
    # Create a node with star and another name
    node = ast.ImportFrom(
        module="module",
        names=[ast.alias(name="*", asname=None), ast.alias(name="name", asname=None)],
        level=0
    )
    extractor = InstanceMappingExtractor()
    # Process the node
    extractor.visit_ImportFrom(node)  # 5.86μs -> 1.92μs (205% faster)


def test_visit_import_from_alias_overwrites_same_name():
    """Test that importing the same name twice with different aliases uses the latest."""
    code = """from module import name as alias1
from module import name as alias2"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_single_char_names():
    """Test importing single character names."""
    code = "from module import a, b, c"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_very_long_names():
    """Test importing very long names."""
    long_name = "a" * 100
    code = f"from module import {long_name}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_alias_same_as_original():
    """Test when alias is the same as the original name."""
    code = "from module import name as name"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_preserves_state():
    """Test that the extractor preserves state across multiple visits."""
    code1 = "from module1 import name1"
    code2 = "from module2 import name2"
    tree1 = ast.parse(code1)
    tree2 = ast.parse(code2)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree1)
    extractor.visit(tree2)


def test_visit_import_from_many_imports_single_statement():
    """Test importing many names from a single module."""
    # Generate a large import statement
    names = [f"name{i}" for i in range(500)]
    code = f"from module import {', '.join(names)}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass


def test_visit_import_from_many_imports_multiple_statements():
    """Test processing many separate ImportFrom statements."""
    code_lines = [f"from module{i} import name{i}" for i in range(500)]
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass


def test_visit_import_from_many_aliases():
    """Test importing many names with aliases."""
    names = [f"name{i} as alias{i}" for i in range(500)]
    code = f"from module import {', '.join(names)}"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)
    for i in range(500):
        pass


def test_visit_import_from_complex_mixed_imports():
    """Test complex scenario with many mixed imports."""
    code_lines = []
    for i in range(250):
        if i % 3 == 0:
            # Non-aliased import
            code_lines.append(f"from module{i} import name{i}")
        elif i % 3 == 1:
            # Aliased import
            code_lines.append(f"from module{i} import name{i} as alias{i}")
        else:
            # Multiple imports
            code_lines.append(f"from module{i} import name{i}a, name{i}b, name{i}c")
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_deeply_nested_code():
    """Test ImportFrom statements within nested code structures."""
    code = """
def func():
    if True:
        from module1 import name1
    for i in range(10):
        from module2 import name2
    try:
        from module3 import name3
    except:
        pass
    from module4 import name4
"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_instance_mapping_remains_empty():
    """Test that instance_mapping remains empty after visiting ImportFrom nodes."""
    code = "from module import name1, name2 as alias2"
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_set_behavior_duplicates():
    """Test that imported_modules set prevents duplicates."""
    # Manually create scenario where same name could be added twice
    node1 = ast.ImportFrom(module="module1", names=[ast.alias(name="name", asname=None)], level=0)
    node2 = ast.ImportFrom(module="module2", names=[ast.alias(name="name", asname=None)], level=0)
    extractor = InstanceMappingExtractor()
    extractor.visit_ImportFrom(node1)  # 4.74μs -> 1.77μs (167% faster)
    initial_size = len(extractor.imported_modules)
    extractor.visit_ImportFrom(node2)  # 2.77μs -> 892ns (211% faster)
    final_size = len(extractor.imported_modules)


def test_visit_import_from_generic_visit_called():
    """Test that generic_visit is properly called for child traversal."""
    # Create a nested AST structure that would require generic_visit
    code = """
from module import (
    name1,
    name2,
    name3
)
"""
    tree = ast.parse(code)
    extractor = InstanceMappingExtractor()
    extractor.visit(tree)


def test_visit_import_from_return_value_is_none():
    """Test that visit_ImportFrom returns None (standard NodeVisitor behavior)."""
    code = "from module import name"
    tree = ast.parse(code)
    # Get the ImportFrom node
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            extractor = InstanceMappingExtractor()
            codeflash_output = extractor.visit_ImportFrom(node)
            result = codeflash_output
            break

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally: `git merge codeflash/optimize-pr1418-2026-02-17T13.07.22`
Suggested changes
| for alias in node.names: | |
| if alias.name == "*": | |
| continue | |
| imported_name = alias.asname if alias.asname else alias.name | |
| self.imported_modules.add(imported_name) | |
| if alias.asname: | |
| self.alias_mapping[imported_name] = alias.name | |
| self.generic_visit(node) | |
| # Preserve original behavior: if a subclass has overridden generic_visit, | |
| # call it to allow that custom traversal to run. Otherwise, avoid the | |
| # heavy generic traversal and only visit alias children if a specific | |
| # visit_alias handler exists. | |
| if getattr(type(self), "generic_visit", ast.NodeVisitor.generic_visit) is not ast.NodeVisitor.generic_visit: | |
| self.generic_visit(node) | |
| return | |
| has_visit_alias = hasattr(self, "visit_alias") | |
| for alias in node.names: | |
| if alias.name == "*": | |
| continue | |
| imported_name = alias.asname if alias.asname else alias.name | |
| self.imported_modules.add(imported_name) | |
| if alias.asname: | |
| self.alias_mapping[imported_name] = alias.name | |
| if has_visit_alias: | |
| # If a visitor for alias nodes exists, invoke it to match | |
| # the behavior of the default generic_visit which would have | |
| # dispatched to visit_alias. | |
| self.visit(alias) |
Summary
- … `nn.Module` forward method when called via instance (e.g., `model(input_data)`)
- … `model = ClassName(...); model(input_data)` where `model(input_data)` internally calls `forward()`
- … `forward` methods

Problem
When running `codeflash --function AlexNet.forward`, tests with this pattern weren't being instrumented: `model = AlexNet(...); model(input_data)`
The instrumentation was looking for direct calls to `forward` or `AlexNet`, but `model(input_data)` matched neither.

Solution
- … `collect_instance_variables()` to track variables assigned from class instantiations
- … `find_and_update_line_node()` to wrap calls to instance variables when optimizing `forward` methods

Test plan
- … `test_pytorch_forward_method_instrumentation` test case

🤖 Generated with Claude Code