Skip to content

Comments

⚡️ Speed up function _add_timing_instrumentation by 16% in PR #1580 (fix/java-direct-jvm-and-bugs)#1595

Merged
claude[bot] merged 2 commits intofix/java-direct-jvm-and-bugsfrom
codeflash/optimize-pr1580-2026-02-20T09.26.51
Feb 20, 2026
Merged

⚡️ Speed up function _add_timing_instrumentation by 16% in PR #1580 (fix/java-direct-jvm-and-bugs)#1595
claude[bot] merged 2 commits intofix/java-direct-jvm-and-bugsfrom
codeflash/optimize-pr1580-2026-02-20T09.26.51

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Feb 20, 2026

⚡️ This pull request contains optimizations for PR #1580

If you approve this dependent PR, these changes will be merged into the original PR branch fix/java-direct-jvm-and-bugs.

This PR will be automatically closed if the original PR is merged.


📄 16% (0.16x) speedup for _add_timing_instrumentation in codeflash/languages/java/instrumentation.py

⏱️ Runtime : 10.2 milliseconds 8.81 milliseconds (best of 250 runs)

📝 Explanation and details

This optimization achieves a 15% runtime improvement (10.2ms → 8.81ms) by replacing recursive AST traversal with iterative stack-based traversal in two critical functions: collect_test_methods and collect_target_calls.

Key Changes

1. Iterative AST Traversal (Primary Speedup)

  • Replaced recursive tree walking with explicit stack-based iteration
  • In collect_test_methods: Changed from recursive calls to while stack loop with stack.extend(reversed(current.children))
  • In collect_target_calls: Similar transformation using explicit stack management
  • Impact: Line profiler shows collect_test_methods dropped from 24.2% to 3.8% of total runtime (81% reduction in that function)

2. Why This Works in Python

  • Python function calls have significant overhead (frame creation, argument binding, scope setup)
  • Recursive traversal compounds this overhead across potentially deep AST trees
  • Iterative approach uses a simple list for the stack, avoiding repeated function call overhead
  • The reversed() call ensures children are processed in the same order as recursive traversal, preserving correctness

3. Performance Characteristics
Based on annotated tests:

  • Large method bodies (500+ lines): 23.8% faster - most benefit from reduced recursion overhead
  • Many test methods (100 methods): 9.2% faster - cumulative savings across many traversals
  • Simple cases: 2-5% faster - overhead reduction still measurable
  • Empty/no-match cases: Minor regression (8-9% slower) due to negligible baseline times (12-40μs)

Impact on Workloads

The function references show _add_timing_instrumentation is called from test instrumentation code. This optimization particularly benefits:

  • Java projects with large test suites containing many @Test methods
  • Complex test methods with deep AST structures and multiple method invocations
  • Batch instrumentation operations where the function is called repeatedly

The iterative approach scales better than recursion as AST depth and method count increase, making it especially valuable for large Java codebases where instrumentation is applied across hundreds of test methods.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 9 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import types
from typing import List

# imports
import pytest  # used for our unit tests
# Import the function under test
from codeflash.languages.java import instrumentation as instrumentation_mod
from codeflash.languages.java.instrumentation import \
    _add_timing_instrumentation

# We will provide a lightweight, deterministic "analyzer" implementation for the
# instrumentation module to use during tests. This analyzer mimics just enough of
# the tree-sitter-like API expected by the instrumentation code so tests can run
# deterministically without requiring a real tree-sitter Java grammar build.
#
# NOTE: This test-created analyzer uses real Python classes defined here to shape
# nodes. The instrumentation module will call get_java_analyzer(), so we monkeypatch
# that function to return an instance of our TestAnalyzer. This keeps the module
# under test unchanged while providing a controllable parse implementation.
#
# The TestAnalyzer.parse method inspects the provided bytes and constructs a simple
# AST with:
# - a root_node with children
# - method_declaration nodes for methods annotated with '@Test'
# - method bodies with start/end byte offsets that align with source bytes
# - for wrapper parses (the instrumentation wrapper prefix is used), it produces a
#   method_declaration whose body contains method_invocation nodes for occurrences
#   of the requested function name.
#
# The instrumentation code also calls analyzer.get_node_text(node, source_bytes),
# which we implement to return the substring of the original bytes for the node's
# start/end byte offsets. start_point on the method_declaration nodes is provided
# so the instrumentation can compute indentation.

# Helper lightweight AST node used by our TestAnalyzer.
class SimpleNode:
    def __init__(
        self,
        type_: str,
        start_byte: int = 0,
        end_byte: int = 0,
        start_point: tuple = (0, 0),
    ):
        self.type = type_
        self.children: List[SimpleNode] = []
        # For mapping field names like "body" or "name" to child nodes
        self._field_children: dict = {}
        self.parent = None
        self.start_byte = start_byte
        self.end_byte = end_byte
        self.start_point = start_point
        # end_point not used by instrumentation, so omitted

    def add_child(self, child: "SimpleNode", field_name: str | None = None):
        child.parent = self
        self.children.append(child)
        if field_name:
            self._field_children[field_name] = child

    def child_by_field_name(self, name: str):
        return self._field_children.get(name)

    # Representational convenience
    def __repr__(self):
        return f"<SimpleNode type={self.type} [{self.start_byte}:{self.end_byte}]>"

class TestAnalyzer:
    def __init__(self):
        # store last parsed bytes in case get_node_text needs them easily
        self._last_bytes = b""
        # We will honor wrapper bytes constructed by instrumentation:
        # prefix + body_bytes + suffix. The instrumentation module provides the
        # prefix bytes as _TS_BODY_PREFIX_BYTES; we import it for alignment.
        from codeflash.languages.java.instrumentation import \
            _TS_BODY_PREFIX_BYTES

        self._prefix = _TS_BODY_PREFIX_BYTES

    def parse(self, source_bytes: bytes):
        """Return a fake 'Tree' object with root_node attribute.

        For original source files (not wrapped), we scan for methods annotated
        with '@Test' and create method_declaration nodes with a body child whose
        byte offsets correspond to the braces in the source.

        For wrapper bytes (starting with the special prefix), we construct a
        single method_declaration containing method_invocation nodes for each
        occurrence of the target function name (we do not need to know func_name
        here; the instrumentation will inspect node text via get_node_text).
        """
        self._last_bytes = source_bytes

        # Simple container to mimic a tree_sitter Tree with a root_node attribute.
        class FakeTree:
            def __init__(self, root_node):
                self.root_node = root_node

        # If the bytes begin with the wrapper prefix, it's the wrapper parse call.
        if source_bytes.startswith(self._prefix):
            # For wrapper, create a single method_declaration node whose body covers
            # the wrapped body region (prefix .. suffix). The instrumentation expects
            # to find method_invocation nodes within that body; we'll create one
            # method_invocation node for every occurrence of a Java identifier-like
            # token in the body (we'll rely on the instrumentation's func_name match
            # to select the right ones).
            root = SimpleNode("program", 0, len(source_bytes))
            # Find the first occurrence of "void _m()" in the wrapper to approximate method location
            start = 0
            # Find wrapper body braces to set body node offsets; body is between the
            # prefix and suffix in wrapper_bytes
            prefix_len = len(self._prefix)
            suffix_bytes = b"}}"
            try:
                suffix_index = source_bytes.rfind(suffix_bytes)
                body_start = prefix_len
                body_end = suffix_index if suffix_index != -1 else len(source_bytes) - len(suffix_bytes)
            except Exception:
                body_start = prefix_len
                body_end = len(source_bytes)
            method_node = SimpleNode("method_declaration", start, len(source_bytes), (0, 0))
            body_node = SimpleNode("block", body_start, body_end, (0, 0))
            method_node.add_child(body_node, field_name="body")
            # Identify candidate method_invocation spans by searching for '(' which likely follows names,
            # but create invocation nodes for any contiguous identifier characters followed by '('
            import re

            # We'll create invocation nodes for any pattern like funcName( in the body.
            for m in re.finditer(rb"([A-Za-z_][A-Za-z0-9_]*)\s*\(", source_bytes[body_start:body_end]):
                # Compute occurrence bytes relative to wrapper bytes
                name = m.group(1)
                name_start = body_start + m.start(1)
                name_end = body_start + m.end(1)
                # Create a method_invocation node spanning the entire call (approx until ')')
                # For simplicity, span until the next ')' after the '('
                paren_index = body_start + m.end(0) - 1  # index of '(' relative to wrapper
                # find corresponding ')'
                closing = source_bytes.find(b")", paren_index)
                if closing == -1:
                    closing = paren_index + 1
                call_node = SimpleNode("method_invocation", name_start, closing + 1, (0, 0))
                # name child of invocation
                name_node = SimpleNode("identifier", name_start, name_end, (0, 0))
                call_node.add_child(name_node, field_name="name")
                body_node.add_child(call_node)
            root.add_child(method_node)
            return FakeTree(root)

        # Otherwise, it's a parse of a normal source file.
        text = source_bytes.decode("utf8")
        root = SimpleNode("program", 0, len(source_bytes))
        # Very naive parsing:
        # Find class boundaries and methods annotated with '@Test'
        import re

        # Find all '@Test' occurrences and then find the subsequent method body braces.
        for m in re.finditer(r"@Test\b", text):
            ann_start = len(text[: m.start()].encode("utf8"))
            ann_end = len(text[: m.end()].encode("utf8"))
            # Build modifiers node containing the annotation
            modifiers_node = SimpleNode("modifiers", ann_start, ann_end, (0, 0))
            ann_node = SimpleNode("annotation", ann_start, ann_end, (0, 0))
            modifiers_node.add_child(ann_node)
            # Heuristically find the method declaration that follows the annotation:
            # Look ahead for the first '{' after the annotation to mark the body.
            brace_index = text.find("{", m.end())
            if brace_index == -1:
                continue
            # Find matching closing brace for this method body by naive counting.
            idx = brace_index
            depth = 0
            end_index = -1
            while idx < len(text):
                if text[idx] == "{":
                    depth += 1
                elif text[idx] == "}":
                    depth -= 1
                    if depth == 0:
                        end_index = idx
                        break
                idx += 1
            if end_index == -1:
                continue
            # Determine byte offsets
            method_start_byte = len(text[: m.start()].encode("utf8"))
            method_end_byte = len(text[: end_index + 1].encode("utf8"))
            # Provide a start_point column: count characters on the line where the method starts
            # We'll set column to 4 for tests as our sample sources use that indentation
            method_node = SimpleNode("method_declaration", method_start_byte, method_end_byte, (0, 4))
            # Body node excludes the outermost braces in instrumentation code: instrumentation uses
            # body_node.start_byte + 1 and end_byte -1 to slice body contents. So we set body_node to cover braces.
            body_open_byte = len(text[: brace_index].encode("utf8"))
            body_close_byte = len(text[: end_index].encode("utf8")) + 1  # include closing brace
            body_node = SimpleNode("block", body_open_byte, body_close_byte, (0, 0))
            method_node.add_child(modifiers_node)
            method_node.add_child(body_node, field_name="body")
            # Add the method node under root (we ignore class nodes for simplicity)
            root.add_child(method_node)
        return FakeTree(root)

    def get_node_text(self, node: SimpleNode, source_bytes: bytes) -> str:
        """Return the decoded text covered by the node's byte span."""
        # Guard: if node doesn't have start_byte/end_byte, return empty string
        try:
            return source_bytes[node.start_byte:node.end_byte].decode("utf8")
        except Exception:
            return ""

def test_no_test_methods_returns_source():
    # Basic: if there are no @Test methods, the original source should be returned unchanged.
    source = """public class A {
    public void helper() {
        // no tests here
    }
}
"""
    # Call the function under test; since our analyzer finds no '@Test', it should return the identical source.
    codeflash_output = _add_timing_instrumentation(source, "A", "target"); out = codeflash_output # 37.0μs -> 40.4μs (8.53% slower)

def test_single_test_method_instrumentation_inserts_loop_and_markers():
    # Single @Test method containing a call to target(); instrumentation should insert the inner loop
    source = """public class MyTest {
    @Test
    public void testOne() {
        target();
    }
}
"""
    # instrument
    codeflash_output = _add_timing_instrumentation(source, "MyTest", "target"); out = codeflash_output # 98.1μs -> 95.8μs (2.38% faster)

def test_variable_declaration_hoisted_for_assignment():
    # When the target() call appears inside a local variable declaration like:
    # int len = target();
    # The instrumentation should hoist "int len = 0;" and replace the original with "len = target();"
    source = """public class HoistTest {
    @Test
    public void testLen() {
        int len = target();
        int after = len;
    }
}
"""
    codeflash_output = _add_timing_instrumentation(source, "HoistTest", "target"); out = codeflash_output # 123μs -> 118μs (4.09% faster)

def test_multiple_calls_in_single_method_produces_multiple_wrappers():
    # A test method with two separate statements each invoking target() should produce two timing wrappers
    source = """public class MultiCallTest {
    @Test
    public void testBoth() {
        helper();
        target();
        another();
        target();
    }
}
"""
    codeflash_output = _add_timing_instrumentation(source, "MultiCallTest", "target"); out = codeflash_output # 133μs -> 126μs (5.62% faster)
    # Expect instrumentation comment to appear twice (once per targeted statement)
    count_comments = out.count("// Codeflash timing instrumentation with inner loop for JIT warmup")

def test_empty_source_returns_empty_string():
    # Edge: empty source should return empty string and not crash
    codeflash_output = _add_timing_instrumentation("", "C", "f"); out = codeflash_output # 12.1μs -> 13.4μs (9.42% slower)

def test_unicode_characters_preserved_and_instrumented():
    # Source contains non-ASCII characters inside strings; instrumentation must handle UTF-8 correctly
    source = """public class UnicodeTest {
    @Test
    public void testUnicode() {
        String s = "é世";
        target();
    }
}
"""
    codeflash_output = _add_timing_instrumentation(source, "UnicodeTest", "target"); out = codeflash_output # 118μs -> 114μs (3.61% faster)

def test_class_and_function_names_with_special_characters():
    # Use class and function names with underscores and digits to verify instrumentation string injection works
    source = """public class C_123 {
    @Test
    public void test_fn_1() {
        target_1();
    }
}
"""
    codeflash_output = _add_timing_instrumentation(source, "C_123", "target_1"); out = codeflash_output # 92.2μs -> 89.9μs (2.54% faster)

def test_large_number_of_test_methods_instrumented():
    # Build a large source with many @Test methods, each calling target()
    n = 100  # within the requested up-to-1000; kept smaller for test runtime
    methods = []
    for i in range(n):
        methods.append(f"    @Test\n    public void test{i}() {{\n        target();\n    }}\n")
    source = "public class BigTest {\n" + "".join(methods) + "}\n"
    codeflash_output = _add_timing_instrumentation(source, "BigTest", "target"); out = codeflash_output # 4.45ms -> 4.08ms (9.18% faster)
    # Expect the instrumentation comment to appear n times (one per targeted statement/method)
    count_comments = out.count("// Codeflash timing instrumentation with inner loop for JIT warmup")

def test_large_inner_iterations_handling_string_size():
    # Stress the reindentation and string handling by creating a single test method with deep content
    # Build a large method body with many lines and a single target() near the end.
    body_lines = ["        int x = 0;"] * 500  # many lines to stress reindent_block
    body_lines.append("        target();")
    body_text = "\n".join(body_lines)
    source = f"""public class StressTest {{
    @Test
    public void longMethod() {{
{body_text}
    }}
}}
"""
    codeflash_output = _add_timing_instrumentation(source, "StressTest", "target"); out = codeflash_output # 5.11ms -> 4.13ms (23.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1580-2026-02-20T09.26.51 and push.

Codeflash Static Badge

This optimization achieves a **15% runtime improvement** (10.2ms → 8.81ms) by replacing recursive AST traversal with iterative stack-based traversal in two critical functions: `collect_test_methods` and `collect_target_calls`.

## Key Changes

**1. Iterative AST Traversal (Primary Speedup)**
- Replaced recursive tree walking with explicit stack-based iteration
- In `collect_test_methods`: Changed from recursive calls to `while stack` loop with `stack.extend(reversed(current.children))`
- In `collect_target_calls`: Similar transformation using explicit stack management
- **Impact**: Line profiler shows `collect_test_methods` dropped from 24.2% to 3.8% of total runtime (81% reduction in that function)

**2. Why This Works in Python**
- Python function calls have significant overhead (frame creation, argument binding, scope setup)
- Recursive traversal compounds this overhead across potentially deep AST trees
- Iterative approach uses a simple list for the stack, avoiding repeated function call overhead
- The `reversed()` call ensures children are processed in the same order as recursive traversal, preserving correctness

**3. Performance Characteristics**
Based on annotated tests:
- **Large method bodies** (500+ lines): 23.8% faster - most benefit from reduced recursion overhead
- **Many test methods** (100 methods): 9.2% faster - cumulative savings across many traversals
- **Simple cases**: 2-5% faster - overhead reduction still measurable
- **Empty/no-match cases**: Minor regression (8-9% slower) due to negligible baseline times (12-40μs)

## Impact on Workloads

The function references show `_add_timing_instrumentation` is called from test instrumentation code. This optimization particularly benefits:
- **Java projects with large test suites** containing many `@Test` methods
- **Complex test methods** with deep AST structures and multiple method invocations
- **Batch instrumentation operations** where the function is called repeatedly

The iterative approach scales better than recursion as AST depth and method count increase, making it especially valuable for large Java codebases where instrumentation is applied across hundreds of test methods.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Feb 20, 2026
@claude
Copy link
Contributor

claude bot commented Feb 20, 2026

PR Review Summary

Prek Checks

Passed — Auto-fixed 2 extra blank lines in instrumentation.py (formatting). Committed and pushed as af617e6b.

Mypy

Passed — No type errors in codeflash/languages/java/instrumentation.py.

Code Review

No critical issues found.

The change converts collect_test_methods and collect_target_calls from recursive to iterative (stack-based) AST traversal. This is a correct transformation:

  • collect_test_methods: Uses continue after matching a @Test method, skipping child exploration. This is safe — Java doesn't support nested method declarations, so no @Test methods can exist inside another method's body.
  • collect_target_calls: Always explores children after checking for method_invocation, preserving the original recursive behavior exactly.
  • Traversal order is preserved via reversed(current.children) with stack.pop().

Test Coverage

File Stmts Miss Coverage
codeflash/languages/java/instrumentation.py 515 92 82%

Changed lines (729-753): 24/25 lines covered ✅

  • Only line 751 (logger.debug for lambda/complex expression skip path) is uncovered — this is a minor edge case.
  • All core iterative traversal logic (stack operations, type checks, child exploration) is fully covered by existing tests.

Overall project coverage: 79% (49,207 / 61,967 statements)


Last updated: 2026-02-20

@claude claude bot merged commit ae1c03d into fix/java-direct-jvm-and-bugs Feb 20, 2026
23 of 30 checks passed
@claude claude bot deleted the codeflash/optimize-pr1580-2026-02-20T09.26.51 branch February 20, 2026 12:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants