⚡️ Speed up function _create_cpu_timing_try_body by 25% in PR #1335 (gpu-flag) #1344

Closed
codeflash-ai[bot] wants to merge 5 commits into gpu-flag from codeflash/optimize-pr1335-2026-02-04T00.13.36

Conversation


@codeflash-ai codeflash-ai bot commented Feb 4, 2026

⚡️ This pull request contains optimizations for PR #1335

If you approve this dependent PR, these changes will be merged into the original PR branch gpu-flag.

This PR will be automatically closed if the original PR is merged.


📄 25% (0.25x) speedup for _create_cpu_timing_try_body in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 1.19 milliseconds → 952 microseconds (best of 250 runs)

📝 Explanation and details

The optimization achieves a 25% speedup (1.19ms → 952μs) by eliminating redundant AST node construction through two key strategies:

Primary Optimization: LRU Caching of AST Structures

The code extracts framework-specific AST generation into separate cached functions (_create_torch_sync_ast, _create_jax_sync_ast, _create_tf_sync_ast) decorated with @lru_cache(maxsize=32). This is highly effective because:

  1. Eliminates Repeated Construction: The line profiler shows the original code spending significant time constructing identical AST nodes on every call. For example, the PyTorch sync statement construction (ast.If, nested ast.Attribute, ast.Call, etc.) took ~791μs for just the MPS test name creation alone. With caching, these structures are built once per framework alias and reused.

  2. Dramatic Per-Call Speedup: Tests with frameworks show the most significant improvements:

    • test_with_torch_framework: 22.0μs → 11.1μs (98.5% faster)
    • test_with_multiple_frameworks: 30.5μs → 11.4μs (168% faster)
    • test_with_tensorflow_framework: 17.3μs → 10.6μs (62.4% faster)
  3. Cumulative Benefits: In test_multiple_consecutive_calls (100 iterations), the speedup is 710μs → 651μs (9.1%), showing consistent cache hits across repeated invocations.

Secondary Optimization: Shared Context Objects

The code pre-creates _LOAD_CTX and _STORE_CTX as module-level constants, reusing the same ast.Load() and ast.Store() instances throughout. This reduces object allocation overhead, particularly visible in _create_cpu_timing_try_body where context objects are used 10+ times per call.
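A minimal sketch of the shared-context pattern, using the constant names from the description (`_LOAD_CTX`/`_STORE_CTX`; the helper `make_assign` is illustrative, not from the PR): `ast.Load()` and `ast.Store()` instances carry no state, so a single instance of each can back every `Name` node instead of allocating a fresh context object per node.

```python
import ast

# Module-level singletons: expression contexts are stateless markers,
# so one Load and one Store can be shared by every node that needs one.
_LOAD_CTX = ast.Load()
_STORE_CTX = ast.Store()


def make_assign(target: str, source: str) -> ast.Assign:
    """Build `<target> = <source>`, reusing the shared context objects."""
    node = ast.Assign(
        targets=[ast.Name(id=target, ctx=_STORE_CTX)],
        value=ast.Name(id=source, ctx=_LOAD_CTX),
    )
    return ast.fix_missing_locations(node)


print(ast.unparse(make_assign("counter", "start")))  # counter = start
```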

Performance Impact

The line profiler confirms _create_device_sync_statements total time drops from 2.28ms to 0.56ms (75% reduction). The caching is especially beneficial when the same framework configurations are used repeatedly, which is typical in test instrumentation scenarios where the same frameworks are synchronized across many test cases. Tests without frameworks show modest 5-10% gains (context object reuse), while framework-heavy tests show 60-168% improvements (cache hits on AST structures).

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 130 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests
import ast

import pytest
from codeflash.code_utils.instrument_existing_tests import \
    _create_cpu_timing_try_body

class TestCreateCpuTimingTryBodyBasic:
    """Basic test cases for _create_cpu_timing_try_body function."""

    def test_returns_list_of_ast_statements(self):
        """Test that the function returns a list of AST statements."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 11.3μs -> 10.7μs (5.69% faster)

    def test_with_none_frameworks(self):
        """Test that the function works with None as input (no frameworks used)."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.6μs -> 10.1μs (4.77% faster)
        
    def test_with_empty_frameworks_dict(self):
        """Test that the function works with an empty dictionary."""
        codeflash_output = _create_cpu_timing_try_body({}); result = codeflash_output # 10.7μs -> 9.86μs (8.74% faster)

    def test_contains_counter_assignment(self):
        """Test that the result contains a counter assignment statement."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.7μs -> 9.73μs (9.58% faster)
        # Find the assignment to 'counter'
        counter_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "counter" for target in stmt.targets)
        ]
        assignment = counter_assignments[0]

    def test_contains_return_value_assignment(self):
        """Test that the result contains a return_value assignment."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.5μs -> 9.58μs (9.72% faster)
        # Find the assignment to 'return_value'
        return_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "return_value" for target in stmt.targets)
        ]
        assignment = return_assignments[0]

    def test_contains_duration_assignment(self):
        """Test that the result contains a codeflash_duration assignment."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.6μs -> 9.65μs (9.75% faster)
        # Find the assignment to 'codeflash_duration'
        duration_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "codeflash_duration" for target in stmt.targets)
        ]
        assignment = duration_assignments[0]

class TestCreateCpuTimingTryBodyWithFrameworks:
    """Test cases for _create_cpu_timing_try_body with various framework combinations."""

    def test_with_torch_framework(self):
        """Test that torch framework generates appropriate sync statements."""
        frameworks = {"torch": "torch"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 22.0μs -> 11.1μs (98.5% faster)
        # Should include device sync statements
        if_statements = [stmt for stmt in result if isinstance(stmt, ast.If)]
        
    def test_with_tensorflow_framework(self):
        """Test that tensorflow framework generates appropriate sync statements."""
        frameworks = {"tensorflow": "tf"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 17.3μs -> 10.6μs (62.4% faster)
        if_statements = [stmt for stmt in result if isinstance(stmt, ast.If)]

    def test_with_jax_framework(self):
        """Test that jax framework generates appropriate sync statements."""
        frameworks = {"jax": "jax"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 13.9μs -> 10.6μs (30.6% faster)
        # JAX sync should be generated (after function call for return_value)
        if_statements = [stmt for stmt in result if isinstance(stmt, ast.If)]

    def test_with_multiple_frameworks(self):
        """Test that multiple frameworks generate all necessary sync statements."""
        frameworks = {"torch": "torch", "tensorflow": "tf", "jax": "jax"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 30.5μs -> 11.4μs (168% faster)
        if_statements = [stmt for stmt in result if isinstance(stmt, ast.If)]

    def test_with_custom_torch_alias(self):
        """Test that custom torch alias is used in generated code."""
        frameworks = {"torch": "th"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 21.5μs -> 10.7μs (101% faster)
        # Verify the alias is used somewhere in the generated code
        code_str = ast.unparse(ast.Module(body=result, type_ignores=[]))

    def test_with_custom_tensorflow_alias(self):
        """Test that custom tensorflow alias is used in generated code."""
        frameworks = {"tensorflow": "tensorflow_alias"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 17.5μs -> 10.7μs (64.0% faster)
        code_str = ast.unparse(ast.Module(body=result, type_ignores=[]))

    def test_with_custom_jax_alias(self):
        """Test that custom jax alias is used in generated code."""
        frameworks = {"jax": "jx"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 14.3μs -> 10.5μs (35.6% faster)
        code_str = ast.unparse(ast.Module(body=result, type_ignores=[]))

class TestCreateCpuTimingTryBodyEdgeCases:
    """Edge case tests for _create_cpu_timing_try_body function."""

    def test_frameworks_with_extra_unknown_framework(self):
        """Test that unknown frameworks in dict don't cause errors."""
        frameworks = {"torch": "torch", "unknown_framework": "unk"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 21.8μs -> 10.5μs (107% faster)

    def test_frameworks_dict_with_empty_string_alias(self):
        """Test handling of empty string as framework alias."""
        frameworks = {"torch": ""}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 21.5μs -> 10.8μs (100% faster)

    def test_statement_order_preservation(self):
        """Test that the order of statements is correct: sync, counter, call, sync, duration."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.4μs -> 9.76μs (6.16% faster)
        # Extract meaningful statement indices
        counter_idx = None
        return_idx = None
        duration_idx = None
        
        for i, stmt in enumerate(result):
            if isinstance(stmt, ast.Assign):
                target_id = stmt.targets[0].id if isinstance(stmt.targets[0], ast.Name) else None
                if target_id == "counter":
                    counter_idx = i
                elif target_id == "return_value":
                    return_idx = i
                elif target_id == "codeflash_duration":
                    duration_idx = i

    def test_all_statements_have_ast_types(self):
        """Test that all returned items are valid AST statement types."""
        frameworks = {"torch": "torch", "tensorflow": "tf", "jax": "jax"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 30.5μs -> 11.5μs (165% faster)
        for stmt in result:
            pass

    def test_perf_counter_ns_calls(self):
        """Test that perf_counter_ns is called at least twice (before and after)."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.5μs -> 9.71μs (8.04% faster)
        code_str = ast.unparse(ast.Module(body=result, type_ignores=[]))
        # Count occurrences of perf_counter_ns
        count = code_str.count("perf_counter_ns()")

    def test_duration_calculation_uses_subtraction(self):
        """Test that duration is calculated using subtraction."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.5μs -> 9.82μs (7.43% faster)
        duration_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "codeflash_duration" for target in stmt.targets)
        ]
        assignment = duration_assignments[0]

    def test_codeflash_wrapped_call_includes_args(self):
        """Test that codeflash_wrapped call unpacks *args."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.4μs -> 9.72μs (6.81% faster)
        return_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "return_value" for target in stmt.targets)
        ]
        assignment = return_assignments[0]
        call = assignment.value

    def test_codeflash_wrapped_call_includes_kwargs(self):
        """Test that codeflash_wrapped call includes **kwargs."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.4μs -> 9.64μs (7.80% faster)
        return_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "return_value" for target in stmt.targets)
        ]
        assignment = return_assignments[0]
        call = assignment.value

class TestCreateCpuTimingTryBodyLargeScale:
    """Large scale and performance tests for _create_cpu_timing_try_body."""

    def test_many_frameworks_in_dict(self):
        """Test with a large number of framework entries (even if most are unknown)."""
        # Create a dict with many entries
        frameworks = {f"framework_{i}": f"alias_{i}" for i in range(100)}
        # Add known frameworks
        frameworks["torch"] = "torch"
        frameworks["tensorflow"] = "tf"
        frameworks["jax"] = "jax"
        
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 31.1μs -> 12.0μs (159% faster)

    def test_code_generation_performance(self):
        """Test that code generation is fast even with many frameworks."""
        frameworks = {f"torch_{i}" if i % 3 == 0 else f"tf_{i}" if i % 3 == 1 else f"jax_{i}": f"alias_{i}" for i in range(50)}
        # Add actual framework names for ones that should be recognized
        frameworks["torch"] = "torch"
        frameworks["tensorflow"] = "tf"
        frameworks["jax"] = "jax"
        
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 30.3μs -> 11.7μs (158% faster)

    def test_unparse_large_ast(self):
        """Test that the generated AST can be unparsed to valid code."""
        frameworks = {"torch": "torch", "tensorflow": "tf", "jax": "jax"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 30.0μs -> 11.2μs (167% faster)
        # This should not raise any exception
        code_str = ast.unparse(ast.Module(body=result, type_ignores=[]))

    def test_multiple_consecutive_calls(self):
        """Test multiple consecutive calls to the function."""
        for _ in range(100):
            codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 710μs -> 651μs (9.10% faster)

    def test_statement_consistency_across_calls(self):
        """Test that multiple calls produce structurally identical results."""
        frameworks = {"torch": "torch", "tensorflow": "tf"}
        results = [_create_cpu_timing_try_body(frameworks) for _ in range(10)]
        
        # All should contain the same key assignments
        for result in results:
            assignments = [
                stmt for stmt in result
                if isinstance(stmt, ast.Assign)
            ]
            target_names = set()
            for assign in assignments:
                for target in assign.targets:
                    if isinstance(target, ast.Name):
                        target_names.add(target.id)

class TestCreateCpuTimingTryBodyASTValidation:
    """Test cases that validate the AST structure in detail."""

    def test_return_value_call_has_correct_structure(self):
        """Test that the return_value call has the correct AST structure."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.7μs -> 9.72μs (10.4% faster)
        return_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "return_value" for target in stmt.targets)
        ]
        assignment = return_assignments[0]
        call = assignment.value
        
        # Verify *args is present
        starred_args = [arg for arg in call.args if isinstance(arg, ast.Starred)]
        
        # Verify **kwargs is present
        kwargs_keywords = [kw for kw in call.keywords if kw.arg is None]

    def test_counter_perf_counter_ns_structure(self):
        """Test the structure of the perf_counter_ns call for counter."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.2μs -> 9.50μs (7.59% faster)
        counter_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "counter" for target in stmt.targets)
        ]
        
        assignment = counter_assignments[0]
        call = assignment.value

    def test_duration_calculation_structure(self):
        """Test the structure of the duration calculation."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.5μs -> 9.62μs (8.85% faster)
        duration_assignments = [
            stmt for stmt in result
            if isinstance(stmt, ast.Assign)
            and any(isinstance(target, ast.Name) and target.id == "codeflash_duration" for target in stmt.targets)
        ]
        
        assignment = duration_assignments[0]
        binop = assignment.value

    def test_torch_sync_with_cuda_and_mps(self):
        """Test that torch sync checks both CUDA and MPS."""
        frameworks = {"torch": "torch"}
        codeflash_output = _create_cpu_timing_try_body(frameworks); result = codeflash_output # 21.4μs -> 10.5μs (103% faster)
        
        if_statements = [stmt for stmt in result if isinstance(stmt, ast.If)]
        
        # Find the torch sync statement
        torch_sync = None
        for stmt in if_statements:
            if isinstance(stmt.test, ast.Name):
                if stmt.test.id == "_codeflash_should_sync_cuda":
                    torch_sync = stmt
                    break

    def test_lineno_attributes_set(self):
        """Test that lineno attributes are set on assignment statements."""
        codeflash_output = _create_cpu_timing_try_body(None); result = codeflash_output # 10.3μs -> 9.60μs (7.20% faster)
        assignments = [stmt for stmt in result if isinstance(stmt, ast.Assign)]
        
        for assignment in assignments:
            pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
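The try-body structure these tests validate (counter before the call, the wrapped call with `*args`/`**kwargs`, duration as a subtraction after) can be sketched by parsing the equivalent source; this is an assumed reconstruction from the assertions above, not the function's actual output.

```python
import ast

# Source equivalent of the timing try-body the tests describe:
# perf_counter_ns before the call, the wrapped call, then the difference.
body_src = """
counter = perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = perf_counter_ns() - counter
"""

stmts = ast.parse(body_src).body
targets = [s.targets[0].id for s in stmts if isinstance(s, ast.Assign)]
print(targets)  # ['counter', 'return_value', 'codeflash_duration']
```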

To edit these changes, run `git checkout codeflash/optimize-pr1335-2026-02-04T00.13.36` and push.


aseembits93 and others added 5 commits February 3, 2026 14:33
Add a `gpu` parameter to instrument tests with torch.cuda.Event timing
instead of time.perf_counter_ns() for measuring GPU kernel execution time.
Falls back to CPU timing when CUDA is not available/initialized.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Fix unused variables, single-item membership tests, unnecessary lambdas,
and ternary expressions that can use `or` operator.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
(Commit message: identical to the optimization explanation above.)
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Feb 4, 2026

KRRT7 commented Feb 19, 2026

Closing stale bot PR.

@KRRT7 KRRT7 closed this Feb 19, 2026
@KRRT7 KRRT7 deleted the codeflash/optimize-pr1335-2026-02-04T00.13.36 branch February 19, 2026 12:56