From 1cdb7dc1e9bf5be72a6b7ac814ced0d2de321b9d Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Sun, 19 Apr 2026 02:04:55 +0530 Subject: [PATCH 1/3] Refactor create_wrapper_function in instrument_existing_tests to reduce complexity and isolate wrapper-building concerns --- .../code_utils/instrument_existing_tests.py | 427 ++++++++---------- 1 file changed, 198 insertions(+), 229 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 02450bb12..c8f40425a 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1032,11 +1032,28 @@ def _create_device_sync_statements( return sync_statements -def create_wrapper_function( - mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None -) -> ast.FunctionDef: - lineno = 1 - wrapper_body: list[ast.stmt] = [ +def _create_perf_counter_call() -> ast.Call: + return ast.Call( + func=ast.Attribute(value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()), + args=[], + keywords=[], + ) + + +def _create_codeflash_duration_assignment(lineno: int) -> ast.Assign: + return ast.Assign( + targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], + value=ast.BinOp( + left=_create_perf_counter_call(), + op=ast.Sub(), + right=ast.Name(id="counter", ctx=ast.Load()), + ), + lineno=lineno, + ) + + +def _create_wrapper_invocation_tracking_statements(lineno: int) -> list[ast.stmt]: + return [ ast.Assign( targets=[ast.Name(id="test_id", ctx=ast.Store())], value=ast.JoinedStr( @@ -1065,11 +1082,7 @@ def create_wrapper_function( ), body=[ ast.Assign( - targets=[ - ast.Attribute( - value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Store() - ) - ], + targets=[ast.Attribute(value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Store())], value=ast.Dict(keys=[], values=[]), lineno=lineno + 3, ) @@ -1088,9 +1101,7 @@ def create_wrapper_function( body=[ ast.AugAssign( target=ast.Subscript( - value=ast.Attribute( - value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load() - ), + value=ast.Attribute(value=ast.Name(id="codeflash_wrap", ctx=ast.Load()), attr="index", ctx=ast.Load()), slice=ast.Name(id="test_id", ctx=ast.Load()), ctx=ast.Store(), ), @@ -1136,66 +1147,186 @@ def create_wrapper_function( ), lineno=lineno + 8, ), - *( - [ - ast.Assign( - targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())], - value=ast.JoinedStr( - values=[ - ast.FormattedValue( - value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1 - ), - ast.Constant(value=":"), - ast.FormattedValue( - value=ast.IfExp( - test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), - body=ast.BinOp( - left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), - op=ast.Add(), - right=ast.Constant(value="."), - ), - orelse=ast.Constant(value=""), - ), - conversion=-1, - ), - ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1), - ast.Constant(value=":"), - ast.FormattedValue( - value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1 - ), - ast.Constant(value=":"), - ast.FormattedValue( - value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1 + ] + + +def _create_wrapper_stdout_statements(lineno: int) -> list[ast.stmt]: + return [ + ast.Assign( + targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())], + value=ast.JoinedStr( + values=[ + ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1), + ast.Constant(value=":"), + ast.FormattedValue( + value=ast.IfExp( + test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), + body=ast.BinOp( + left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), + op=ast.Add(), + right=ast.Constant(value="."), ), - ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1), + orelse=ast.Constant(value=""), + ), + conversion=-1, + ), + ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1), + ast.Constant(value=":"), + ast.FormattedValue(value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1), + ast.Constant(value=":"), + ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1), + ast.Constant(value=":"), + ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1), + ] + ), + lineno=lineno + 9, + ), + ast.Expr( + value=ast.Call( + func=ast.Name(id="print", ctx=ast.Load()), + args=[ + ast.JoinedStr( + values=[ + ast.Constant(value="!$######"), + ast.FormattedValue(value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1), + ast.Constant(value="######$!"), ] + ) + ], + keywords=[], + ) + ), + ] + + +def _create_wrapper_try_statement(lineno: int, used_frameworks: dict[str, str] | None) -> ast.Try: + return ast.Try( + body=[ + *_create_device_sync_statements(used_frameworks, for_return_value=False), + ast.Assign( + targets=[ast.Name(id="counter", ctx=ast.Store())], + value=_create_perf_counter_call(), + lineno=lineno + 11, + ), + ast.Assign( + targets=[ast.Name(id="return_value", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), + args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], + keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], + ), + lineno=lineno + 12, + ), + *_create_device_sync_statements(used_frameworks, for_return_value=True), + _create_codeflash_duration_assignment(lineno + 13), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="Exception", ctx=ast.Load()), + name="e", + body=[ + _create_codeflash_duration_assignment(lineno + 15), + ast.Assign( + targets=[ast.Name(id="exception", ctx=ast.Store())], + value=ast.Name(id="e", ctx=ast.Load()), + lineno=lineno + 13, ), - lineno=lineno + 9, + ], + lineno=lineno + 14, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno + 11, + ) + + +def _create_wrapper_behavior_statements(lineno: int) -> list[ast.stmt]: + return [ + ast.Assign( + targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())], + value=ast.IfExp( + test=ast.Name(id="exception", ctx=ast.Load()), + body=ast.Call( + func=ast.Attribute(value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load()), + args=[ast.Name(id="exception", ctx=ast.Load())], + keywords=[], ), - ast.Expr( - value=ast.Call( - func=ast.Name(id="print", ctx=ast.Load()), - args=[ - ast.JoinedStr( - values=[ - ast.Constant(value="!$######"), - ast.FormattedValue( - value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1 - ), - ast.Constant(value="######$!"), - ] - ) - ], - keywords=[], - ) + orelse=ast.Call( + func=ast.Attribute(value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load()), + args=[ast.Name(id="return_value", ctx=ast.Load())], + keywords=[], ), - ] + ), + lineno=lineno + 18, + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load()), + args=[ + ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"), + ast.Tuple( + elts=[ + ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), + ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), + ast.Name(id="codeflash_test_name", ctx=ast.Load()), + ast.Name(id="codeflash_function_name", ctx=ast.Load()), + ast.Name(id="codeflash_loop_index", ctx=ast.Load()), + ast.Name(id="invocation_id", ctx=ast.Load()), + ast.Name(id="codeflash_duration", ctx=ast.Load()), + ast.Name(id="pickled_return_value", ctx=ast.Load()), + ast.Constant(value=VerificationType.FUNCTION_CALL.value), + ], + ctx=ast.Load(), + ), + ], + keywords=[], + ), + lineno=lineno + 20, + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="commit", ctx=ast.Load()), + args=[], + keywords=[], + ), + lineno=lineno + 21, ), + ] + + +def _create_wrapper_args(mode: TestingMode) -> ast.arguments: + return ast.arguments( + args=[ + ast.arg(arg="codeflash_wrapped", annotation=None), + ast.arg(arg="codeflash_test_module_name", annotation=None), + ast.arg(arg="codeflash_test_class_name", annotation=None), + ast.arg(arg="codeflash_test_name", annotation=None), + ast.arg(arg="codeflash_function_name", annotation=None), + ast.arg(arg="codeflash_line_id", annotation=None), + ast.arg(arg="codeflash_loop_index", annotation=None), + *([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []), + *([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []), + ], + vararg=ast.arg(arg="args"), + kwarg=ast.arg(arg="kwargs"), + posonlyargs=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + + +def create_wrapper_function( + mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None +) -> ast.FunctionDef: + lineno = 1 + wrapper_body: list[ast.stmt] = [ + *_create_wrapper_invocation_tracking_statements(lineno), + *_create_wrapper_stdout_statements(lineno), ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), - # Pre-compute device sync conditions before profiling to avoid overhead during timing *_create_device_sync_precompute_statements(used_frameworks), ast.Expr( value=ast.Call( @@ -1205,83 +1336,7 @@ def create_wrapper_function( ), lineno=lineno + 9, ), - ast.Try( - body=[ - # Pre-sync: synchronize device before starting timer - *_create_device_sync_statements(used_frameworks, for_return_value=False), - ast.Assign( - targets=[ast.Name(id="counter", ctx=ast.Store())], - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() - ), - args=[], - keywords=[], - ), - lineno=lineno + 11, - ), - ast.Assign( - targets=[ast.Name(id="return_value", ctx=ast.Store())], - value=ast.Call( - func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), - args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], - keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], - ), - lineno=lineno + 12, - ), - # Post-sync: synchronize device after function call to ensure all device work is complete - *_create_device_sync_statements(used_frameworks, for_return_value=True), - ast.Assign( - targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], - value=ast.BinOp( - left=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() - ), - args=[], - keywords=[], - ), - op=ast.Sub(), - right=ast.Name(id="counter", ctx=ast.Load()), - ), - lineno=lineno + 13, - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Name(id="Exception", ctx=ast.Load()), - name="e", - body=[ - ast.Assign( - targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], - value=ast.BinOp( - left=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), - attr="perf_counter_ns", - ctx=ast.Load(), - ), - args=[], - keywords=[], - ), - op=ast.Sub(), - right=ast.Name(id="counter", ctx=ast.Load()), - ), - lineno=lineno + 15, - ), - ast.Assign( - targets=[ast.Name(id="exception", ctx=ast.Store())], - value=ast.Name(id="e", ctx=ast.Load()), - lineno=lineno + 13, - ), - ], - lineno=lineno + 14, - ) - ], - orelse=[], - finalbody=[], - lineno=lineno + 11, - ), + _create_wrapper_try_statement(lineno, used_frameworks), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()), @@ -1314,75 +1369,7 @@ def create_wrapper_function( keywords=[], ) ), - *( - [ - ast.Assign( - targets=[ast.Name(id="pickled_return_value", ctx=ast.Store())], - value=ast.IfExp( - test=ast.Name(id="exception", ctx=ast.Load()), - body=ast.Call( - func=ast.Attribute( - value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load() - ), - args=[ast.Name(id="exception", ctx=ast.Load())], - keywords=[], - ), - orelse=ast.Call( - func=ast.Attribute( - value=ast.Name(id="pickle", ctx=ast.Load()), attr="dumps", ctx=ast.Load() - ), - args=[ast.Name(id="return_value", ctx=ast.Load())], - keywords=[], - ), - ), - lineno=lineno + 18, - ) - ] - if mode == TestingMode.BEHAVIOR - else [] - ), - *( - [ - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="codeflash_cur", ctx=ast.Load()), attr="execute", ctx=ast.Load() - ), - args=[ - ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"), - ast.Tuple( - elts=[ - ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), - ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), - ast.Name(id="codeflash_test_name", ctx=ast.Load()), - ast.Name(id="codeflash_function_name", ctx=ast.Load()), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - ast.Name(id="invocation_id", ctx=ast.Load()), - ast.Name(id="codeflash_duration", ctx=ast.Load()), - ast.Name(id="pickled_return_value", ctx=ast.Load()), - ast.Constant(value=VerificationType.FUNCTION_CALL.value), - ], - ctx=ast.Load(), - ), - ], - keywords=[], - ), - lineno=lineno + 20, - ), - ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="codeflash_con", ctx=ast.Load()), attr="commit", ctx=ast.Load() - ), - args=[], - keywords=[], - ), - lineno=lineno + 21, - ), - ] - if mode == TestingMode.BEHAVIOR - else [] - ), + *(_create_wrapper_behavior_statements(lineno) if mode == TestingMode.BEHAVIOR else []), ast.If( test=ast.Name(id="exception", ctx=ast.Load()), body=[ast.Raise(exc=ast.Name(id="exception", ctx=ast.Load()), cause=None, lineno=lineno + 22)], @@ -1393,25 +1380,7 @@ def create_wrapper_function( ] return ast.FunctionDef( name="codeflash_wrap", - args=ast.arguments( - args=[ - ast.arg(arg="codeflash_wrapped", annotation=None), - ast.arg(arg="codeflash_test_module_name", annotation=None), - ast.arg(arg="codeflash_test_class_name", annotation=None), - ast.arg(arg="codeflash_test_name", annotation=None), - ast.arg(arg="codeflash_function_name", annotation=None), - ast.arg(arg="codeflash_line_id", annotation=None), - ast.arg(arg="codeflash_loop_index", annotation=None), - *([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []), - *([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []), - ], - vararg=ast.arg(arg="args"), - kwarg=ast.arg(arg="kwargs"), - posonlyargs=[], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), + args=_create_wrapper_args(mode), body=wrapper_body, lineno=lineno, decorator_list=[], From 5a1efbf34d153032d67b382fe4e6014dc0a44f75 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Sun, 19 Apr 2026 02:15:27 +0530 Subject: [PATCH 2/3] refactor: add missing type hints in instrument_existing_tests --- codeflash/code_utils/instrument_existing_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index c8f40425a..7888eb4f3 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -15,7 +15,7 @@ from codeflash.models.models import FunctionParent, TestingMode, VerificationType if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator from codeflash.models.models import CodePosition @@ -90,7 +90,7 @@ def find_and_update_line_node( # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements. # Helper for manual walk - def iter_ast_calls(node): + def iter_ast_calls(node: ast.AST) -> Iterator[ast.Call]: # Generator to yield each ast.Call in test_node, preserves node identity stack = [node] while stack: From 8e6b558fbf37797d4ae55328809f18f95b5c84a5 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Sun, 19 Apr 2026 02:17:24 +0530 Subject: [PATCH 3/3] docs: add docstrings to public instrumentation helpers --- .../code_utils/instrument_existing_tests.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 7888eb4f3..ceedf4ea4 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -27,10 +27,12 @@ class FunctionCallNodeArguments: def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments: + """Capture the positional and keyword arguments from a call node.""" return FunctionCallNodeArguments(call_node.args, call_node.keywords) def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool: + """Return whether an AST call node overlaps one of the recorded call positions.""" # Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty. # Small optimizations for tight loop: if isinstance(node, ast.Call): @@ -57,6 +59,7 @@ def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> def is_argument_name(name: str, arguments_node: ast.arguments) -> bool: + """Check whether a name appears anywhere in a function signature's argument lists.""" return any( element.arg == name for attribute_name in dir(arguments_node) @@ -710,6 +713,13 @@ def inject_profiling_into_existing_test( tests_project_root: Path, mode: TestingMode = TestingMode.BEHAVIOR, ) -> tuple[bool, str | None]: + """Instrument matching calls in an existing sync test file and return the rewritten source. + + This is the main entry point for Python test instrumentation. It rewrites the target test + module so matching function calls are wrapped with CodeFlash profiling logic, then prepends + the imports and helper wrapper needed for the selected testing mode. + + """ tests_project_root = tests_project_root.resolve() if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( @@ -1320,6 +1330,12 @@ def _create_wrapper_args(mode: TestingMode) -> ast.arguments: def create_wrapper_function( mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None ) -> ast.FunctionDef: + """Build the `codeflash_wrap` AST function used to profile instrumented sync test calls. + + The generated wrapper is responsible for invocation tracking, timing, optional framework + synchronization, stdout tagging, and behavior-mode result persistence. + + """ lineno = 1 wrapper_body: list[ast.stmt] = [ *_create_wrapper_invocation_tracking_statements(lineno), @@ -1638,6 +1654,7 @@ async def async_wrapper(*args, **kwargs): def get_decorator_name_for_mode(mode: TestingMode) -> str: + """Return the async instrumentation decorator name for the requested testing mode.""" if mode == TestingMode.BEHAVIOR: return "codeflash_behavior_async" if mode == TestingMode.CONCURRENCY: @@ -1701,5 +1718,6 @@ def add_async_decorator_to_function( def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path: + """Create the temp-file path used to hold an instrumented copy of a source module.""" instrumented_filename = f"instrumented_{source_path.name}" return temp_dir / instrumented_filename