From c311e0c47758b7499a60a388b8ee0a2ac4e2f8c8 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 23 Apr 2026 13:21:19 +0200 Subject: [PATCH 01/10] base --- codeflash/cli_cmds/cmd_init.py | 5 + codeflash/cli_cmds/init_go.py | 188 ++++++++++ codeflash/cli_cmds/init_javascript.py | 5 + codeflash/languages/__init__.py | 6 + codeflash/languages/current.py | 10 + codeflash/languages/golang/__init__.py | 0 codeflash/languages/golang/config.py | 66 ++++ codeflash/languages/golang/discovery.py | 147 ++++++++ codeflash/languages/golang/parser.py | 334 ++++++++++++++++++ codeflash/languages/golang/support.py | 189 ++++++++++ codeflash/languages/language_enum.py | 1 + codeflash/languages/registry.py | 4 + codeflash/setup/detector.py | 57 +++ codeflash/version.py | 2 +- pyproject.toml | 1 + .../fixtures/go_project/calculator.go | 53 +++ .../fixtures/go_project/calculator_test.go | 34 ++ .../test_languages/fixtures/go_project/go.mod | 7 + tests/test_languages/test_golang/__init__.py | 0 .../test_languages/test_golang/test_config.py | 47 +++ .../test_golang/test_discovery.py | 163 +++++++++ .../test_languages/test_golang/test_parser.py | 218 ++++++++++++ .../test_golang/test_support.py | 120 +++++++ uv.lock | 53 +++ 24 files changed, 1709 insertions(+), 1 deletion(-) create mode 100644 codeflash/cli_cmds/init_go.py create mode 100644 codeflash/languages/golang/__init__.py create mode 100644 codeflash/languages/golang/config.py create mode 100644 codeflash/languages/golang/discovery.py create mode 100644 codeflash/languages/golang/parser.py create mode 100644 codeflash/languages/golang/support.py create mode 100644 tests/test_languages/fixtures/go_project/calculator.go create mode 100644 tests/test_languages/fixtures/go_project/calculator_test.go create mode 100644 tests/test_languages/fixtures/go_project/go.mod create mode 100644 tests/test_languages/test_golang/__init__.py create mode 100644 tests/test_languages/test_golang/test_config.py create mode 100644 tests/test_languages/test_golang/test_discovery.py create mode 100644 tests/test_languages/test_golang/test_parser.py create mode 100644 tests/test_languages/test_golang/test_support.py diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index d4da0ed04..bd44cb761 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -29,6 +29,7 @@ get_suggestions, should_modify_pyproject_toml, ) +from codeflash.cli_cmds.init_go import init_go_project from codeflash.cli_cmds.init_java import init_java_project from codeflash.cli_cmds.init_javascript import ProjectLanguage, detect_project_language, init_js_project from codeflash.code_utils.code_utils import validate_relative_directory_path @@ -61,6 +62,10 @@ def init_codeflash() -> None: # Detect project language project_language = detect_project_language() + if project_language == ProjectLanguage.GO: + init_go_project() + return + if project_language == ProjectLanguage.JAVA: init_java_project() return diff --git a/codeflash/cli_cmds/init_go.py b/codeflash/cli_cmds/init_go.py new file mode 100644 index 000000000..032072231 --- /dev/null +++ b/codeflash/cli_cmds/init_go.py @@ -0,0 +1,188 @@ +"""Go project initialization for Codeflash.""" + +from __future__ import annotations + +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Union + +import click +import inquirer +from git import InvalidGitRepositoryError, Repo +from rich.console import Group +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from codeflash.cli_cmds.console import console +from codeflash.code_utils.compat import LF +from codeflash.code_utils.git_utils import get_git_remotes +from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell +from codeflash.languages.golang.config import detect_go_project, detect_go_version +from codeflash.telemetry.posthog_cf import ph + + +@dataclass(frozen=True) +class GoSetupInfo: + module_root_override: Union[str, None] = None + test_root_override: Union[str, None] = None + formatter_override: Union[list[str], None] = None + git_remote: str = "origin" + disable_telemetry: bool = False + ignore_paths: list[str] | None = None + + +def _get_theme() -> Any: + from codeflash.cli_cmds.init_config import CodeflashTheme + + return CodeflashTheme() + + +def init_go_project() -> None: + from codeflash.cli_cmds.github_workflow import install_github_actions + from codeflash.cli_cmds.init_auth import install_github_app, prompt_api_key + + lang_panel = Panel( + Text( + "Go project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center" + ), + title="Go Setup", + border_style="bright_cyan", + ) + console.print(lang_panel) + console.print() + + did_add_new_key = prompt_api_key() + + setup_info = collect_go_setup_info() + git_remote = setup_info.git_remote or "origin" + + install_github_app(git_remote) + + install_github_actions(override_formatter_check=True) + + usage_table = Table(show_header=False, show_lines=False, border_style="dim") + usage_table.add_column("Command", style="cyan") + usage_table.add_column("Description", style="white") + + usage_table.add_row("codeflash --file --function ", "Optimize a specific function") + usage_table.add_row("codeflash --all", "Optimize all functions in all files") + usage_table.add_row("codeflash --help", "See all available options") + + completion_message = "Codeflash is now set up for your Go project!\n\nYou can now run any of these commands:" + + if did_add_new_key: + completion_message += ( + "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!" + ) + if os.name == "nt": + reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}" + else: + reload_cmd = f"source {get_shell_rc_path()}" + completion_message += f"\nOr run: {reload_cmd}" + + completion_panel = Panel( + Group(Text(completion_message, style="bold green"), Text(""), usage_table), + title="Setup Complete!", + border_style="bright_green", + padding=(1, 2), + ) + console.print(completion_panel) + + ph("cli-go-installation-successful", {"did_add_new_key": did_add_new_key}) + sys.exit(0) + + +def collect_go_setup_info() -> GoSetupInfo: + + from codeflash.cli_cmds.init_config import ask_for_telemetry + + curdir = Path.cwd() + + if not os.access(curdir, os.W_OK): + click.echo(f"The current directory isn't writable, please check your folder permissions and try again.{LF}") + sys.exit(1) + + config = detect_go_project(curdir) + module_path = config.module_path if config else "unknown" + go_version = (config.go_version if config else None) or detect_go_version() or "unknown" + has_vendor = config.has_vendor if config else False + + detection_table = Table(show_header=False, box=None, padding=(0, 2)) + detection_table.add_column("Setting", style="cyan") + detection_table.add_column("Value", style="green") + detection_table.add_row("Module", module_path) + detection_table.add_row("Go version", go_version) + detection_table.add_row("Source root", ".") + detection_table.add_row("Test root", ". (co-located)") + detection_table.add_row("Formatter", "gofmt") + if has_vendor: + detection_table.add_row("Vendor", "yes (vendor/ detected)") + + detection_panel = Panel( + Group(Text("Auto-detected settings for your Go project:\n", style="cyan"), detection_table), + title="Auto-Detection Results", + border_style="bright_blue", + ) + console.print(detection_panel) + console.print() + + git_remote = _get_git_remote_for_setup() + + disable_telemetry = not ask_for_telemetry() + + return GoSetupInfo(git_remote=git_remote, disable_telemetry=disable_telemetry) + + +def _get_git_remote_for_setup() -> str: + try: + repo = Repo(Path.cwd(), search_parent_directories=True) + git_remotes = get_git_remotes(repo) + if not git_remotes: + return "" + + if len(git_remotes) == 1: + return git_remotes[0] + + git_panel = Panel( + Text( + "Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.", + style="blue", + ), + title="Git Remote Setup", + border_style="bright_blue", + ) + console.print(git_panel) + console.print() + + git_questions = [ + inquirer.List( + "git_remote", + message="Which git remote should Codeflash use?", + choices=git_remotes, + default="origin", + carousel=True, + ) + ] + + git_answers = inquirer.prompt(git_questions, theme=_get_theme()) + return git_answers["git_remote"] if git_answers else git_remotes[0] + except InvalidGitRepositoryError: + return "" + + +def get_go_runtime_setup_steps() -> str: + return """- name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: 'stable'""" + + +def get_go_dependency_installation_commands() -> str: + return "go mod download" + + +def get_go_test_command() -> str: + return "go test ./..." diff --git a/codeflash/cli_cmds/init_javascript.py b/codeflash/cli_cmds/init_javascript.py index fcd3c4b57..20f76d249 100644 --- a/codeflash/cli_cmds/init_javascript.py +++ b/codeflash/cli_cmds/init_javascript.py @@ -38,6 +38,7 @@ class ProjectLanguage(Enum): JAVASCRIPT = auto() TYPESCRIPT = auto() JAVA = auto() + GO = auto() class JsPackageManager(Enum): @@ -89,6 +90,10 @@ def detect_project_language(project_root: Path | None = None) -> ProjectLanguage """ root = project_root or Path.cwd() + # Go detection (go.mod is definitive) + if (root / "go.mod").exists(): + return ProjectLanguage.GO + # Java detection (pom.xml or build.gradle is definitive) has_pom = (root / "pom.xml").exists() has_gradle = (root / "build.gradle").exists() or (root / "build.gradle.kts").exists() diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index b0daea0fb..0ec0f87fd 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -31,6 +31,7 @@ from codeflash.languages.current import ( current_language, current_language_support, + is_go, is_java, is_javascript, is_python, @@ -83,6 +84,10 @@ def __getattr__(name: str): from codeflash.languages.java.support import JavaSupport return JavaSupport + if name == "GoSupport": + from codeflash.languages.golang.support import GoSupport + + return GoSupport msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) @@ -106,6 +111,7 @@ def __getattr__(name: str): "get_language_support", "get_supported_extensions", "get_supported_languages", + "is_go", "is_java", "is_javascript", "is_jest", diff --git a/codeflash/languages/current.py b/codeflash/languages/current.py index b9e45d367..8be5fd07a 100644 --- a/codeflash/languages/current.py +++ b/codeflash/languages/current.py @@ -113,6 +113,16 @@ def is_java() -> bool: return _current_language == Language.JAVA +def is_go() -> bool: + """Check if the current language is Go. + + Returns: + True if the current language is Go. + + """ + return _current_language == Language.GO + + def current_language_support() -> LanguageSupport: """Get the LanguageSupport instance for the current language. diff --git a/codeflash/languages/golang/__init__.py b/codeflash/languages/golang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/languages/golang/config.py b/codeflash/languages/golang/config.py new file mode 100644 index 000000000..25eae4cce --- /dev/null +++ b/codeflash/languages/golang/config.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import logging +import re +import subprocess +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class GoProjectConfig: + project_root: Path + module_path: str + go_version: str | None = None + has_vendor: bool = False + + +def detect_go_project(project_root: Path) -> GoProjectConfig | None: + go_mod = project_root / "go.mod" + if not go_mod.exists(): + return None + + module_path = "" + go_version = None + + try: + content = go_mod.read_text(encoding="utf-8") + for line in content.splitlines(): + line = line.strip() + if line.startswith("module "): + module_path = line[len("module ") :].strip() + elif line.startswith("go "): + go_version = line[len("go ") :].strip() + except (OSError, UnicodeDecodeError): + logger.warning("Failed to read go.mod at %s", go_mod) + return None + + has_vendor = (project_root / "vendor").is_dir() + + return GoProjectConfig( + project_root=project_root, module_path=module_path, go_version=go_version, has_vendor=has_vendor + ) + + +def detect_go_version() -> str | None: + try: + result = subprocess.run(["go", "version"], capture_output=True, text=True, timeout=10, check=False) + if result.returncode != 0: + return None + match = re.search(r"go(\d+\.\d+(?:\.\d+)?)", result.stdout) + if match: + return match.group(1) + except (FileNotFoundError, subprocess.TimeoutExpired, OSError): + pass + return None + + +def is_go_project(project_root: Path) -> bool: + if (project_root / "go.mod").exists(): + return True + return any(project_root.glob("*.go")) diff --git a/codeflash/languages/golang/discovery.py b/codeflash/languages/golang/discovery.py new file mode 100644 index 000000000..c2e9ede2a --- /dev/null +++ b/codeflash/languages/golang/discovery.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.languages.golang.parser import GoAnalyzer + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.languages.base import FunctionFilterCriteria + from codeflash.languages.golang.parser import GoFunctionNode, GoMethodNode + from codeflash.models.function_types import FunctionToOptimize + + +logger = logging.getLogger(__name__) + +_SKIP_FUNCTION_NAMES = frozenset({"init", "main"}) + + +def discover_functions( + file_path: Path, filter_criteria: FunctionFilterCriteria | None = None, analyzer: GoAnalyzer | None = None +) -> list[FunctionToOptimize]: + try: + source = file_path.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError): + logger.warning("Failed to read Go file: %s", file_path) + return [] + return discover_functions_from_source(source, file_path, filter_criteria, analyzer) + + +def discover_functions_from_source( + source: str, + file_path: Path, + filter_criteria: FunctionFilterCriteria | None = None, + analyzer: GoAnalyzer | None = None, +) -> list[FunctionToOptimize]: + from codeflash.models.function_types import FunctionParent, FunctionToOptimize + + if analyzer is None: + analyzer = GoAnalyzer() + + results: list[FunctionToOptimize] = [] + + functions = analyzer.find_functions(source) + for func in functions: + if not _should_include_function(func, filter_criteria, file_path): + continue + results.append( + FunctionToOptimize( + function_name=func.name, + file_path=file_path, + parents=[], + starting_line=func.starting_line, + ending_line=func.ending_line, + starting_col=func.starting_col, + ending_col=func.ending_col, + is_async=False, + is_method=False, + language="go", + doc_start_line=func.doc_start_line, + ) + ) + + methods = analyzer.find_methods(source) + for method in methods: + if not _should_include_method(method, filter_criteria, file_path): + continue + results.append( + FunctionToOptimize( + function_name=method.name, + file_path=file_path, + parents=[FunctionParent(name=method.receiver_name, type="StructDef")], + starting_line=method.starting_line, + ending_line=method.ending_line, + starting_col=method.starting_col, + ending_col=method.ending_col, + is_async=False, + is_method=True, + language="go", + doc_start_line=method.doc_start_line, + ) + ) + + return results + + +def _should_include_function(func: GoFunctionNode, criteria: FunctionFilterCriteria | None, file_path: Path) -> bool: + if file_path.name.endswith("_test.go"): + return False + + if func.name in _SKIP_FUNCTION_NAMES: + return False + + if criteria is None: + return True + + if criteria.require_export and not func.is_exported: + return False + + if criteria.require_return and not func.has_return_type: + return False + + if criteria.matches_exclude_patterns(func.name): + return False + + if not criteria.matches_include_patterns(func.name): + return False + + line_count = func.ending_line - func.starting_line + 1 + if criteria.min_lines is not None and line_count < criteria.min_lines: + return False + if criteria.max_lines is not None and line_count > criteria.max_lines: + return False + + return True + + +def _should_include_method(method: GoMethodNode, criteria: FunctionFilterCriteria | None, file_path: Path) -> bool: + if file_path.name.endswith("_test.go"): + return False + + if criteria is None: + return True + + if not criteria.include_methods: + return False + + if criteria.require_export and not method.is_exported: + return False + + if criteria.require_return and not method.has_return_type: + return False + + if criteria.matches_exclude_patterns(method.name): + return False + + if not criteria.matches_include_patterns(method.name): + return False + + line_count = method.ending_line - method.starting_line + 1 + if criteria.min_lines is not None and line_count < criteria.min_lines: + return False + if criteria.max_lines is not None and line_count > criteria.max_lines: + return False + + return True diff --git a/codeflash/languages/golang/parser.py b/codeflash/languages/golang/parser.py new file mode 100644 index 000000000..ce0f173e6 --- /dev/null +++ b/codeflash/languages/golang/parser.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from tree_sitter import Language, Parser + +if TYPE_CHECKING: + from tree_sitter import Node, Tree + +logger = logging.getLogger(__name__) + +_GO_LANGUAGE: Language | None = None +_GO_PARSER: Parser | None = None + + +def _get_go_language() -> Language: + global _GO_LANGUAGE + if _GO_LANGUAGE is None: + import tree_sitter_go + + _GO_LANGUAGE = Language(tree_sitter_go.language()) + return _GO_LANGUAGE + + +def _get_go_parser() -> Parser: + global _GO_PARSER + if _GO_PARSER is None: + _GO_PARSER = Parser(_get_go_language()) + return _GO_PARSER + + +@dataclass(frozen=True) +class GoFunctionNode: + name: str + starting_line: int + ending_line: int + starting_col: int + ending_col: int + is_exported: bool + has_return_type: bool + doc_start_line: int | None = None + + +@dataclass(frozen=True) +class GoMethodNode: + name: str + receiver_name: str + receiver_is_pointer: bool + starting_line: int + ending_line: int + starting_col: int + ending_col: int + is_exported: bool + has_return_type: bool + doc_start_line: int | None = None + + +@dataclass(frozen=True) +class GoStructNode: + name: str + starting_line: int + ending_line: int + fields: list[str] = field(default_factory=list) + + +@dataclass(frozen=True) +class GoInterfaceNode: + name: str + starting_line: int + ending_line: int + methods: list[str] = field(default_factory=list) + + +@dataclass(frozen=True) +class GoImportInfo: + path: str + alias: str | None + starting_line: int + ending_line: int + + +class GoAnalyzer: + def __init__(self) -> None: + self._parser = _get_go_parser() + self._source_bytes: bytes | None = None + self._tree: Tree | None = None + + def parse(self, source: str) -> Tree: + self._source_bytes = source.encode("utf-8") + self._tree = self._parser.parse(self._source_bytes) + return self._tree + + def get_node_text(self, node: Node) -> str: + if self._source_bytes is None: + return "" + return self._source_bytes[node.start_byte : node.end_byte].decode("utf-8") + + def validate_syntax(self, source: str) -> bool: + tree = self.parse(source) + return not tree.root_node.has_error + + def find_functions(self, source: str) -> list[GoFunctionNode]: + tree = self.parse(source) + results: list[GoFunctionNode] = [] + for node in tree.root_node.children: + if node.type == "function_declaration": + func = self._parse_function_node(node) + if func is not None: + results.append(func) + return results + + def find_methods(self, source: str) -> list[GoMethodNode]: + tree = self.parse(source) + results: list[GoMethodNode] = [] + for node in tree.root_node.children: + if node.type == "method_declaration": + method = self._parse_method_node(node) + if method is not None: + results.append(method) + return results + + def find_structs(self, source: str) -> list[GoStructNode]: + tree = self.parse(source) + results: list[GoStructNode] = [] + for node in tree.root_node.children: + if node.type == "type_declaration": + for spec in _children_of_type(node, "type_spec"): + type_node = spec.child_by_field_name("type") + if type_node is not None and type_node.type == "struct_type": + name_node = spec.child_by_field_name("name") + if name_node is not None: + fields = self._extract_struct_fields(type_node) + results.append( + GoStructNode( + name=self.get_node_text(name_node), + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + fields=fields, + ) + ) + return results + + def find_interfaces(self, source: str) -> list[GoInterfaceNode]: + tree = self.parse(source) + results: list[GoInterfaceNode] = [] + for node in tree.root_node.children: + if node.type == "type_declaration": + for spec in _children_of_type(node, "type_spec"): + type_node = spec.child_by_field_name("type") + if type_node is not None and type_node.type == "interface_type": + name_node = spec.child_by_field_name("name") + if name_node is not None: + methods = self._extract_interface_methods(type_node) + results.append( + GoInterfaceNode( + name=self.get_node_text(name_node), + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + methods=methods, + ) + ) + return results + + def find_imports(self, source: str) -> list[GoImportInfo]: + tree = self.parse(source) + results: list[GoImportInfo] = [] + for node in tree.root_node.children: + if node.type == "import_declaration": + for spec in _iter_import_specs(node): + path_node = spec.child_by_field_name("path") + if path_node is None: + continue + import_path = self.get_node_text(path_node).strip('"') + alias_node = spec.child_by_field_name("name") + alias = self.get_node_text(alias_node) if alias_node is not None else None + results.append( + GoImportInfo( + path=import_path, + alias=alias, + starting_line=spec.start_point.row + 1, + ending_line=spec.end_point.row + 1, + ) + ) + return results + + def find_package_name(self, source: str) -> str | None: + tree = self.parse(source) + for node in tree.root_node.children: + if node.type == "package_clause": + for child in node.children: + if child.type == "package_identifier": + return self.get_node_text(child) + return None + + def _parse_function_node(self, node: Node) -> GoFunctionNode | None: + name_node = node.child_by_field_name("name") + if name_node is None: + return None + name = self.get_node_text(name_node) + result_node = node.child_by_field_name("result") + doc_line = _find_preceding_comment_line(node) + return GoFunctionNode( + name=name, + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + starting_col=node.start_point.column, + ending_col=node.end_point.column, + is_exported=name[0].isupper(), + has_return_type=result_node is not None, + doc_start_line=doc_line, + ) + + def _parse_method_node(self, node: Node) -> GoMethodNode | None: + name_node = node.child_by_field_name("name") + if name_node is None: + return None + name = self.get_node_text(name_node) + + receiver_node = node.child_by_field_name("receiver") + if receiver_node is None: + return None + receiver_name, receiver_is_pointer = self._parse_receiver(receiver_node) + if receiver_name is None: + return None + + result_node = node.child_by_field_name("result") + doc_line = _find_preceding_comment_line(node) + return GoMethodNode( + name=name, + receiver_name=receiver_name, + receiver_is_pointer=receiver_is_pointer, + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + starting_col=node.start_point.column, + ending_col=node.end_point.column, + is_exported=name[0].isupper(), + has_return_type=result_node is not None, + doc_start_line=doc_line, + ) + + def _parse_receiver(self, receiver_node: Node) -> tuple[str | None, bool]: + for param in _children_of_type(receiver_node, "parameter_declaration"): + type_node = param.child_by_field_name("type") + if type_node is None: + continue + if type_node.type == "pointer_type": + inner = type_node.child(1) + if inner is not None: + return self.get_node_text(inner), True + elif type_node.type == "type_identifier": + return self.get_node_text(type_node), False + return None, False + + def _extract_struct_fields(self, struct_node: Node) -> list[str]: + fields: list[str] = [] + for child in struct_node.children: + if child.type == "field_declaration_list": + for fc in child.children: + if fc.type == "field_declaration": + fields.append(self.get_node_text(fc).strip()) + break + return fields + + def _extract_interface_methods(self, iface_node: Node) -> list[str]: + methods: list[str] = [] + for child in iface_node.children: + if child.type == "method_elem": + methods.append(self.get_node_text(child).strip()) + return methods + + def extract_function_source(self, source: str, func_name: str, receiver_type: str | None = None) -> str | None: + tree = self.parse(source) + for node in tree.root_node.children: + if receiver_type is None and node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and self.get_node_text(name_node) == func_name: + return self._get_source_with_doc(node) + + if receiver_type is not None and node.type == "method_declaration": + name_node = node.child_by_field_name("name") + if name_node is None or self.get_node_text(name_node) != func_name: + continue + recv_node = node.child_by_field_name("receiver") + if recv_node is not None: + recv_name, _ = self._parse_receiver(recv_node) + if recv_name == receiver_type: + return self._get_source_with_doc(node) + return None + + def _get_source_with_doc(self, node: Node) -> str: + doc_line = _find_preceding_comment_line(node) + if doc_line is not None and self._source_bytes is not None: + lines = self._source_bytes.decode("utf-8").splitlines(keepends=True) + start = doc_line - 1 + end = node.end_point.row + 1 + return "".join(lines[start:end]) + return self.get_node_text(node) + + +def _children_of_type(node: Node, type_name: str) -> list[Node]: + return [child for child in node.children if child.type == type_name] + + +def _iter_import_specs(import_node: Node) -> list[Node]: + results: list[Node] = [] + for child in import_node.children: + if child.type == "import_spec": + results.append(child) + elif child.type == "import_spec_list": + results.extend(c for c in child.children if c.type == "import_spec") + return results + + +def _find_preceding_comment_line(node: Node) -> int | None: + prev = node.prev_named_sibling + if prev is None: + return None + if prev.type != "comment": + return None + if prev.end_point.row + 1 != node.start_point.row: + return None + comment_start = prev.start_point.row + 1 + current = prev + while True: + earlier = current.prev_named_sibling + if earlier is None or earlier.type != "comment": + break + if earlier.end_point.row + 1 != current.start_point.row: + break + comment_start = earlier.start_point.row + 1 + current = earlier + return comment_start diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py new file mode 100644 index 000000000..4f5d08b3f --- /dev/null +++ b/codeflash/languages/golang/support.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from codeflash.languages.golang.config import detect_go_project, detect_go_version +from codeflash.languages.golang.discovery import discover_functions_from_source +from codeflash.languages.golang.parser import GoAnalyzer +from codeflash.languages.language_enum import Language +from codeflash.languages.registry import register_language + +if TYPE_CHECKING: + from collections.abc import Sequence + + from codeflash.languages.base import ( + CodeContext, + DependencyResolver, + FunctionFilterCriteria, + HelperFunction, + ReferenceInfo, + ) + from codeflash.models.function_types import FunctionToOptimize + +logger = logging.getLogger(__name__) + + +@register_language +class GoSupport: + def __init__(self) -> None: + self._analyzer = GoAnalyzer() + self._go_version: str | None = None + self._go_version_detected = False + + @property + def language(self) -> Language: + return Language.GO + + @property + def file_extensions(self) -> tuple[str, ...]: + return (".go",) + + @property + def default_file_extension(self) -> str: + return ".go" + + @property + def test_framework(self) -> str: + return "go-test" + + @property + def comment_prefix(self) -> str: + return "//" + + @property + def dir_excludes(self) -> frozenset[str]: + return frozenset({"vendor", "testdata", ".git", "node_modules"}) + + @property + def language_version(self) -> str | None: + if not self._go_version_detected: + self._go_version = detect_go_version() + self._go_version_detected = True + return self._go_version + + @property + def valid_test_frameworks(self) -> tuple[str, ...]: + return ("go-test",) + + @property + def test_result_serialization_format(self) -> str: + return "json" + + @property + def function_optimizer_class(self) -> type: + raise NotImplementedError + + def discover_functions( + self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None + ) -> list[FunctionToOptimize]: + return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer) + + def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionToOptimize]) -> dict[str, list[Any]]: + raise NotImplementedError + + def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: + return self._analyzer.validate_syntax(source) + + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: + raise NotImplementedError + + def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: + raise NotImplementedError + + def find_references( + self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 100 + ) -> list[ReferenceInfo]: + raise NotImplementedError + + def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: + raise NotImplementedError + + def format_code(self, source: str, file_path: Path | None = None) -> str: + raise NotImplementedError + + def normalize_code(self, source: str) -> str: + raise NotImplementedError + + def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: + raise NotImplementedError + + def prepare_module( + self, module_code: str, module_path: Path, project_root: Path + ) -> tuple[dict[Path, Any], None] | None: + raise NotImplementedError + + def setup_test_config(self, test_cfg: Any) -> None: + project_root = getattr(test_cfg, "project_root_path", Path.cwd()) + config = detect_go_project(project_root) + if config is not None and config.go_version: + self._go_version = config.go_version + self._go_version_detected = True + + def detect_module_system(self, project_root: Path, source_file: Path | None = None) -> str | None: + return None + + def run_behavioral_tests(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def run_benchmarking_tests(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def run_line_profile_tests(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def compare_test_results(self, *args: Any, **kwargs: Any) -> tuple[bool, list[Any]]: + raise NotImplementedError + + def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str: + return source + + def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: + return test_source + + def instrument_existing_test(self, *args: Any, **kwargs: Any) -> tuple[bool, str | None]: + raise NotImplementedError + + def postprocess_generated_tests(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def process_generated_test_strings(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def load_coverage(self, *args: Any, **kwargs: Any) -> Any: + return None + + def get_test_file_suffix(self) -> str: + return "_test.go" + + def find_test_root(self, project_root: Path) -> Path | None: + return project_root + + def get_runtime_files(self) -> list[Path]: + return [] + + def ensure_runtime_environment(self, project_root: Path) -> bool: + return detect_go_version() is not None + + def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None: + return None + + def adjust_test_config_for_discovery(self, test_cfg: Any) -> None: + pass + + def add_runtime_comments( + self, test_source: str, original_runtimes: dict[str, Any], optimized_runtimes: dict[str, Any] + ) -> str: + return test_source + + def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str: + raise NotImplementedError + + def get_test_dir_for_source(self, test_dir: Path, source_file: Path | None = None) -> Path | None: + if source_file is not None: + return source_file.parent + return test_dir + + def parse_test_results(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError diff --git a/codeflash/languages/language_enum.py b/codeflash/languages/language_enum.py index 23187cb30..4b72db62b 100644 --- a/codeflash/languages/language_enum.py +++ b/codeflash/languages/language_enum.py @@ -13,6 +13,7 @@ class Language(str, Enum): JAVASCRIPT = "javascript" TYPESCRIPT = "typescript" JAVA = "java" + GO = "go" def __str__(self) -> str: return self.value diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 17a528fae..e151a5e5c 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -54,6 +54,7 @@ def _ensure_languages_registered() -> None: "codeflash.languages.python.support", "codeflash.languages.javascript.support", "codeflash.languages.java.support", + "codeflash.languages.golang.support", ): with contextlib.suppress(ImportError): importlib.import_module(_lang_module) @@ -227,11 +228,14 @@ def get_language_support_by_common_formatters(formatter_cmd: str | list[str]) -> py_formatters = ["black", "isort", "ruff", "autopep8", "yapf", "pyfmt"] js_ts_formatters = ["prettier", "eslint", "biome", "rome", "deno", "standard", "tslint"] + go_formatters = ["gofmt", "goimports", "golines"] if any(cmd in py_formatters for cmd in formatter_cmd): ext = ".py" elif any(cmd in js_ts_formatters for cmd in formatter_cmd): ext = ".js" + elif any(cmd in go_formatters for cmd in formatter_cmd): + ext = ".go" if ext is None: # can't determine language diff --git a/codeflash/setup/detector.py b/codeflash/setup/detector.py index 216dd669d..3d7a0ea45 100644 --- a/codeflash/setup/detector.py +++ b/codeflash/setup/detector.py @@ -172,6 +172,7 @@ def _find_project_root(start_path: Path) -> Path | None: "pom.xml", "build.gradle", "build.gradle.kts", + "go.mod", ] for marker in markers: if (current / marker).exists(): @@ -203,6 +204,11 @@ def _detect_language(project_root: Path) -> tuple[str, float, str]: has_package_json = (project_root / "package.json").exists() has_pom_xml = (project_root / "pom.xml").exists() has_build_gradle = (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists() + has_go_mod = (project_root / "go.mod").exists() + + # Go (go.mod is definitive) + if has_go_mod: + return "go", 1.0, "go.mod found" # Java (pom.xml or build.gradle is definitive) if has_pom_xml: @@ -235,7 +241,10 @@ def _detect_language(project_root: Path) -> tuple[str, float, str]: js_count = len(list(project_root.rglob("*.js"))) ts_count = len(list(project_root.rglob("*.ts"))) java_count = len(list(project_root.rglob("*.java"))) + go_count = len(list(project_root.rglob("*.go"))) + if go_count > 0 and go_count >= max(py_count, js_count, ts_count, java_count): + return "go", 0.5, f"found {go_count} .go files" if java_count > 0 and java_count >= max(py_count, js_count, ts_count): return "java", 0.5, f"found {java_count} .java files" if ts_count > 0: @@ -264,6 +273,8 @@ def _detect_module_root(project_root: Path, language: str) -> tuple[Path, str]: return _detect_js_module_root(project_root) if language == "java": return _detect_java_module_root(project_root) + if language == "go": + return _detect_go_module_root(project_root) return _detect_python_module_root(project_root) @@ -441,6 +452,23 @@ def _detect_java_module_root(project_root: Path) -> tuple[Path, str]: return project_root, "project root" +def _detect_go_module_root(project_root: Path) -> tuple[Path, str]: + """Detect Go module root directory. + + Go projects use go.mod at the module root. The source directory is the + same as the module root (Go packages are directories, not subdirectories). + """ + if (project_root / "go.mod").exists(): + return project_root, "project root (go.mod found)" + + # Check common subdirectories + for subdir in ["cmd", "pkg", "internal"]: + if (project_root / subdir).is_dir(): + return project_root, f"project root ({subdir}/ found)" + + return project_root, "project root" + + def is_build_output_dir(path: Path) -> bool: """Check if a path is within a common build output directory. @@ -474,6 +502,13 @@ def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None, - spec/ (Ruby/JavaScript) """ + # Go: tests are co-located with source files (*_test.go) + if language == "go": + test_files = list(project_root.rglob("*_test.go")) + if test_files: + return project_root, "project root (Go tests co-located with source)" + return project_root, "project root (Go convention: *_test.go)" + # Java: standard Maven/Gradle test layout if language == "java": import xml.etree.ElementTree as ET @@ -558,6 +593,8 @@ def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]: return _detect_js_test_runner(project_root) if language == "java": return _detect_java_test_runner(project_root) + if language == "go": + return "go-test", "go test (built-in)" return _detect_python_test_runner(project_root) @@ -686,6 +723,8 @@ def _detect_formatter(project_root: Path, language: str) -> tuple[list[str], str return _detect_js_formatter(project_root) if language == "java": return _detect_java_formatter(project_root) + if language == "go": + return _detect_go_formatter(project_root) return _detect_python_formatter(project_root) @@ -803,6 +842,19 @@ def _detect_js_formatter(project_root: Path) -> tuple[list[str], str]: return [], "none detected" +def _detect_go_formatter(project_root: Path) -> tuple[list[str], str]: + """Detect Go formatter. + + Go has a universal formatter (gofmt). goimports is preferred if available + because it also manages imports. + """ + if shutil.which("goimports"): + return ["goimports -w $file"], "goimports (auto-detected)" + if shutil.which("gofmt"): + return ["gofmt -w $file"], "gofmt (auto-detected)" + return ["gofmt -w $file"], "gofmt (default)" + + def _detect_ignore_paths(project_root: Path, language: str) -> tuple[list[Path], str]: """Detect paths to ignore during optimization. @@ -836,6 +888,7 @@ def _detect_ignore_paths(project_root: Path, language: str) -> tuple[list[Path], "javascript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"], "typescript": ["node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache"], "java": ["target", "build", ".gradle", ".idea", "out"], + "go": ["vendor", "testdata"], } # Add default ignores @@ -900,6 +953,10 @@ def has_existing_config(project_root: Path) -> tuple[bool, str | None]: except Exception: pass + # Check Go projects — go.mod presence means "configured" + if (project_root / "go.mod").exists(): + return True, "go.mod" + # Check Java build files — zero-config: build file presence means "configured" for build_file in ("pom.xml", "build.gradle", "build.gradle.kts"): if (project_root / build_file).exists(): diff --git a/codeflash/version.py b/codeflash/version.py index 0f1baf8bc..b354a5b56 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.5.post151.dev0+95b62113" +__version__ = "0.20.5.post243.dev0+67cf12392" diff --git a/pyproject.toml b/pyproject.toml index 7701725ea..9fa6ad9a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "tree-sitter-javascript>=0.23.0", "tree-sitter-typescript>=0.23.0", "tree-sitter-java>=0.23.0", + "tree-sitter-go>=0.23.0", "tree-sitter-groovy>=0.1.0", "tree-sitter-kotlin>=1.0.0", "pytest-timeout>=2.1.0", diff --git a/tests/test_languages/fixtures/go_project/calculator.go b/tests/test_languages/fixtures/go_project/calculator.go new file mode 100644 index 000000000..787a45f73 --- /dev/null +++ b/tests/test_languages/fixtures/go_project/calculator.go @@ -0,0 +1,53 @@ +package calculator + +import "math" + +// Add returns the sum of two integers. +func Add(a, b int) int { + return a + b +} + +func Subtract(a, b int) int { + return a - b +} + +// unexported function +func multiply(a, b int) int { + return a * b +} + +// no return type +func init() { + // package initialization +} + +func Fibonacci(n int) int { + if n <= 1 { + return n + } + return Fibonacci(n-1) + Fibonacci(n-2) +} + +// Hypotenuse calculates the hypotenuse of a right triangle. +func Hypotenuse(a, b float64) float64 { + return math.Sqrt(a*a + b*b) +} + +type Calculator struct { + Result float64 +} + +// AddFloat adds a value to the calculator result. +func (c *Calculator) AddFloat(val float64) float64 { + c.Result += val + return c.Result +} + +func (c Calculator) GetResult() float64 { + return c.Result +} + +// Reset zeroes the calculator. +func (c *Calculator) Reset() { + c.Result = 0 +} diff --git a/tests/test_languages/fixtures/go_project/calculator_test.go b/tests/test_languages/fixtures/go_project/calculator_test.go new file mode 100644 index 000000000..c8e6e4d66 --- /dev/null +++ b/tests/test_languages/fixtures/go_project/calculator_test.go @@ -0,0 +1,34 @@ +package calculator + +import "testing" + +func TestAdd(t *testing.T) { + result := Add(2, 3) + if result != 5 { + t.Errorf("Add(2, 3) = %d; want 5", result) + } +} + +func TestSubtract(t *testing.T) { + result := Subtract(5, 3) + if result != 2 { + t.Errorf("Subtract(5, 3) = %d; want 2", result) + } +} + +func TestFibonacci(t *testing.T) { + tests := []struct { + input int + expected int + }{ + {0, 0}, + {1, 1}, + {10, 55}, + } + for _, tt := range tests { + result := Fibonacci(tt.input) + if result != tt.expected { + t.Errorf("Fibonacci(%d) = %d; want %d", tt.input, result, tt.expected) + } + } +} diff --git a/tests/test_languages/fixtures/go_project/go.mod b/tests/test_languages/fixtures/go_project/go.mod new file mode 100644 index 000000000..910687fdd --- /dev/null +++ b/tests/test_languages/fixtures/go_project/go.mod @@ -0,0 +1,7 @@ +module github.com/example/myproject + +go 1.22.0 + +require ( + github.com/stretchr/testify v1.9.0 +) diff --git a/tests/test_languages/test_golang/__init__.py b/tests/test_languages/test_golang/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_languages/test_golang/test_config.py b/tests/test_languages/test_golang/test_config.py new file mode 100644 index 000000000..c42e3cada --- /dev/null +++ b/tests/test_languages/test_golang/test_config.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.config import detect_go_project, is_go_project + +FIXTURES_DIR = Path(__file__).parent.parent / "fixtures" / "go_project" + + +class TestDetectGoProject: + def test_detects_project(self) -> None: + config = detect_go_project(FIXTURES_DIR) + assert config is not None + assert config.module_path == "github.com/example/myproject" + assert config.go_version == "1.22.0" + + def test_no_go_mod(self, tmp_path: Path) -> None: + config = detect_go_project(tmp_path) + assert config is None + + def test_minimal_go_mod(self, tmp_path: Path) -> None: + go_mod = tmp_path / "go.mod" + go_mod.write_text("module example.com/minimal\n\ngo 1.21\n", encoding="utf-8") + config = detect_go_project(tmp_path) + assert config is not None + assert config.module_path == "example.com/minimal" + assert config.go_version == "1.21" + + def test_vendor_detection(self, tmp_path: Path) -> None: + go_mod = tmp_path / "go.mod" + go_mod.write_text("module example.com/vendored\n\ngo 1.22\n", encoding="utf-8") + (tmp_path / "vendor").mkdir() + config = detect_go_project(tmp_path) + assert config is not None + assert config.has_vendor is True + + +class TestIsGoProject: + def test_with_go_mod(self) -> None: + assert is_go_project(FIXTURES_DIR) is True + + def test_without_go_files(self, tmp_path: Path) -> None: + assert is_go_project(tmp_path) is False + + def test_with_go_files_no_mod(self, tmp_path: Path) -> None: + (tmp_path / "main.go").write_text("package main\n", encoding="utf-8") + assert is_go_project(tmp_path) is True diff --git a/tests/test_languages/test_golang/test_discovery.py b/tests/test_languages/test_golang/test_discovery.py new file mode 100644 index 000000000..19d05e6c0 --- /dev/null +++ b/tests/test_languages/test_golang/test_discovery.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.base import FunctionFilterCriteria +from codeflash.languages.golang.discovery import discover_functions_from_source + +GO_SOURCE = """\ +package calculator + +import "math" + +// Add returns the sum of two integers. +func Add(a, b int) int { + return a + b +} + +func subtract(a, b int) int { + return a - b +} + +func init() { + println("setup") +} + +func main() { + println("hello") +} + +func noReturn() { + println("hello") +} + +type Calculator struct { + Result float64 +} + +func (c *Calculator) AddFloat(val float64) float64 { + c.Result += val + return c.Result +} + +func (c Calculator) GetResult() float64 { + return c.Result +} + +func Hypotenuse(a, b float64) float64 { + return math.Sqrt(a*a + b*b) +} +""" + +GO_TEST_SOURCE = """\ +package calculator + +import "testing" + +func TestAdd(t *testing.T) { + result := Add(2, 3) + if result != 5 { + t.Errorf("want 5, got %d", result) + } +} +""" + + +class TestDiscoverFunctions: + def test_discovers_exported_functions(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "Add" in names + assert "Hypotenuse" in names + + def test_discovers_unexported_functions(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "subtract" in names + assert "noReturn" in names + + def test_skips_init_and_main(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "init" not in names + assert "main" not in names + + def test_skips_test_files(self) -> None: + results = discover_functions_from_source(GO_TEST_SOURCE, Path("/project/calc_test.go")) + assert len(results) == 0 + + def test_discovers_methods(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + methods = [f for f in results if f.is_method] + assert len(methods) == 2 + names = [m.function_name for m in methods] + assert "AddFloat" in names + assert "GetResult" in names + + def test_method_parents(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + method = next(f for f in results if f.function_name == "AddFloat") + assert len(method.parents) == 1 + assert method.parents[0].name == "Calculator" + assert method.parents[0].type == "StructDef" + + def test_language_is_go(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + for func in results: + assert func.language == "go" + + def test_is_async_false(self) -> None: + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go")) + for func in results: + assert func.is_async is False + + +class TestDiscoverWithFilters: + def test_filter_export_only(self) -> None: + criteria = FunctionFilterCriteria(require_export=True, require_return=False) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "Add" in names + assert "Hypotenuse" in names + assert "subtract" not in names + assert "noReturn" not in names + + def test_filter_require_return(self) -> None: + criteria = FunctionFilterCriteria(require_export=False, require_return=True) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "Add" in names + assert "noReturn" not in names + + def test_filter_exclude_methods(self) -> None: + criteria = FunctionFilterCriteria(require_export=False, require_return=False, include_methods=False) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + methods = [f for f in results if f.is_method] + assert len(methods) == 0 + + def test_filter_exclude_pattern(self) -> None: + criteria = FunctionFilterCriteria( + require_export=False, require_return=False, exclude_patterns=["subtract"] + ) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "subtract" not in names + assert "Add" in names + + def test_filter_include_pattern(self) -> None: + criteria = FunctionFilterCriteria( + require_export=False, require_return=False, include_patterns=["Add*"] + ) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + names = [f.function_name for f in results] + assert "Add" in names + assert "AddFloat" in names + assert "subtract" not in names + assert "Hypotenuse" not in names + + def test_filter_min_lines(self) -> None: + criteria = FunctionFilterCriteria(require_export=False, require_return=False, min_lines=4) + results = discover_functions_from_source(GO_SOURCE, Path("/project/calc.go"), criteria) + for func in results: + line_count = func.ending_line - func.starting_line + 1 + assert line_count >= 4 diff --git a/tests/test_languages/test_golang/test_parser.py b/tests/test_languages/test_golang/test_parser.py new file mode 100644 index 000000000..5ce663227 --- /dev/null +++ b/tests/test_languages/test_golang/test_parser.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from codeflash.languages.golang.parser import GoAnalyzer + +GO_SOURCE = """\ +package calculator + +import "math" + +// Add returns the sum of two integers. +func Add(a, b int) int { + return a + b +} + +func subtract(a, b int) int { + return a - b +} + +func noReturn() { + println("hello") +} + +type Calculator struct { + Result float64 +} + +// AddFloat adds a value. +func (c *Calculator) AddFloat(val float64) float64 { + c.Result += val + return c.Result +} + +func (c Calculator) GetResult() float64 { + return c.Result +} + +// Reset zeroes the calculator. +func (c *Calculator) Reset() { + c.Result = 0 +} + +type Adder interface { + Add(a, b int) int +} +""" + + +class TestGoAnalyzerFunctions: + def test_find_functions(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + names = [f.name for f in functions] + assert "Add" in names + assert "subtract" in names + assert "noReturn" in names + + def test_exported_detection(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + assert by_name["Add"].is_exported is True + assert by_name["subtract"].is_exported is False + + def test_return_type_detection(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + assert by_name["Add"].has_return_type is True + assert by_name["noReturn"].has_return_type is False + + def test_doc_comment_detection(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + assert by_name["Add"].doc_start_line is not None + assert by_name["subtract"].doc_start_line is None + + def test_line_numbers(self) -> None: + analyzer = GoAnalyzer() + functions = analyzer.find_functions(GO_SOURCE) + by_name = {f.name: f for f in functions} + add_func = by_name["Add"] + assert add_func.starting_line == 6 + assert add_func.ending_line == 8 + + +class TestGoAnalyzerMethods: + def test_find_methods(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + names = [m.name for m in methods] + assert "AddFloat" in names + assert "GetResult" in names + assert "Reset" in names + + def test_receiver_detection(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + by_name = {m.name: m for m in methods} + assert by_name["AddFloat"].receiver_name == "Calculator" + assert by_name["AddFloat"].receiver_is_pointer is True + assert by_name["GetResult"].receiver_name == "Calculator" + assert by_name["GetResult"].receiver_is_pointer is False + + def test_method_doc_comment(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + by_name = {m.name: m for m in methods} + assert by_name["AddFloat"].doc_start_line is not None + assert by_name["Reset"].doc_start_line is not None + assert by_name["GetResult"].doc_start_line is None + + def test_method_exported(self) -> None: + analyzer = GoAnalyzer() + methods = analyzer.find_methods(GO_SOURCE) + for m in methods: + assert m.is_exported is True + + +class TestGoAnalyzerStructs: + def test_find_structs(self) -> None: + analyzer = GoAnalyzer() + structs = analyzer.find_structs(GO_SOURCE) + assert len(structs) == 1 + assert structs[0].name == "Calculator" + assert len(structs[0].fields) > 0 + + def test_struct_field_content(self) -> None: + analyzer = GoAnalyzer() + structs = analyzer.find_structs(GO_SOURCE) + field_text = " ".join(structs[0].fields) + assert "Result" in field_text + assert "float64" in field_text + + +class TestGoAnalyzerInterfaces: + def test_find_interfaces(self) -> None: + analyzer = GoAnalyzer() + interfaces = analyzer.find_interfaces(GO_SOURCE) + assert len(interfaces) == 1 + assert interfaces[0].name == "Adder" + assert len(interfaces[0].methods) > 0 + + +class TestGoAnalyzerImports: + def test_find_imports(self) -> None: + analyzer = GoAnalyzer() + imports = analyzer.find_imports(GO_SOURCE) + assert len(imports) == 1 + assert imports[0].path == "math" + assert imports[0].alias is None + + def test_multi_import(self) -> None: + source = '''\ +package main + +import ( + "fmt" + "os" + str "strings" +) + +func Main() string { + return "hello" +} +''' + analyzer = GoAnalyzer() + imports = analyzer.find_imports(source) + paths = {i.path for i in imports} + assert paths == {"fmt", "os", "strings"} + aliases = {i.path: i.alias for i in imports} + assert aliases["strings"] == "str" + assert aliases["fmt"] is None + + +class TestGoAnalyzerPackage: + def test_find_package_name(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.find_package_name(GO_SOURCE) == "calculator" + + def test_find_package_name_main(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.find_package_name("package main\n\nfunc main() {}") == "main" + + +class TestGoAnalyzerSyntax: + def test_valid_syntax(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.validate_syntax(GO_SOURCE) is True + + def test_invalid_syntax(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.validate_syntax("func {{{invalid") is False + + +class TestGoAnalyzerExtract: + def test_extract_function_source(self) -> None: + analyzer = GoAnalyzer() + source = analyzer.extract_function_source(GO_SOURCE, "Add") + assert source is not None + assert "func Add" in source + assert "return a + b" in source + + def test_extract_function_source_with_doc(self) -> None: + analyzer = GoAnalyzer() + source = analyzer.extract_function_source(GO_SOURCE, "Add") + assert source is not None + assert "// Add returns" in source + + def test_extract_method_source(self) -> None: + analyzer = GoAnalyzer() + source = analyzer.extract_function_source(GO_SOURCE, "AddFloat", receiver_type="Calculator") + assert source is not None + assert "func (c *Calculator) AddFloat" in source + + def test_extract_nonexistent(self) -> None: + analyzer = GoAnalyzer() + assert analyzer.extract_function_source(GO_SOURCE, "DoesNotExist") is None diff --git a/tests/test_languages/test_golang/test_support.py b/tests/test_languages/test_golang/test_support.py new file mode 100644 index 000000000..df67d84eb --- /dev/null +++ b/tests/test_languages/test_golang/test_support.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.support import GoSupport +from codeflash.languages.language_enum import Language +from codeflash.languages.registry import clear_cache, clear_registry, get_language_support + + +class TestGoSupportProperties: + def test_language(self) -> None: + support = GoSupport() + assert support.language == Language.GO + + def test_file_extensions(self) -> None: + support = GoSupport() + assert support.file_extensions == (".go",) + + def test_default_file_extension(self) -> None: + support = GoSupport() + assert support.default_file_extension == ".go" + + def test_test_framework(self) -> None: + support = GoSupport() + assert support.test_framework == "go-test" + + def test_comment_prefix(self) -> None: + support = GoSupport() + assert support.comment_prefix == "//" + + def test_valid_test_frameworks(self) -> None: + support = GoSupport() + assert support.valid_test_frameworks == ("go-test",) + + def test_serialization_format(self) -> None: + support = GoSupport() + assert support.test_result_serialization_format == "json" + + def test_get_test_file_suffix(self) -> None: + support = GoSupport() + assert support.get_test_file_suffix() == "_test.go" + + def test_dir_excludes(self) -> None: + support = GoSupport() + assert "vendor" in support.dir_excludes + assert "testdata" in support.dir_excludes + + +class TestGoSupportRegistration: + def test_lookup_by_language_enum(self) -> None: + support = get_language_support(Language.GO) + assert support.language == Language.GO + + def test_lookup_by_extension(self) -> None: + support = get_language_support(Path("main.go")) + assert support.language == Language.GO + + def test_lookup_by_string(self) -> None: + support = get_language_support("go") + assert support.language == Language.GO + + def test_lookup_by_dot_extension(self) -> None: + support = get_language_support(".go") + assert support.language == Language.GO + + +class TestGoSupportDiscoverFunctions: + def test_discovers_functions(self) -> None: + support = GoSupport() + source = """\ +package calc + +func Add(a, b int) int { + return a + b +} + +func subtract(a, b int) int { + return a - b +} +""" + results = support.discover_functions(source, Path("/project/calc.go")) + names = [f.function_name for f in results] + assert "Add" in names + assert "subtract" in names + + def test_validate_syntax_valid(self) -> None: + support = GoSupport() + assert support.validate_syntax("package main\n\nfunc main() {}") is True + + def test_validate_syntax_invalid(self) -> None: + support = GoSupport() + assert support.validate_syntax("func {{{ invalid") is False + + +class TestGoSupportHelpers: + def test_find_test_root(self) -> None: + support = GoSupport() + root = Path("/project") + assert support.find_test_root(root) == root + + def test_get_runtime_files(self) -> None: + support = GoSupport() + assert support.get_runtime_files() == [] + + def test_instrument_for_behavior_passthrough(self) -> None: + support = GoSupport() + source = "package main\n\nfunc main() {}\n" + assert support.instrument_for_behavior(source, []) == source + + def test_instrument_for_benchmarking_passthrough(self) -> None: + support = GoSupport() + source = "package main\n\nfunc Test() {}\n" + result = support.instrument_for_benchmarking(source, None) # type: ignore[arg-type] + assert result == source + + def test_get_test_dir_for_source(self) -> None: + support = GoSupport() + source_file = Path("/project/pkg/calc.go") + result = support.get_test_dir_for_source(Path("/project"), source_file) + assert result == Path("/project/pkg") diff --git a/uv.lock b/uv.lock index c059d601e..e357561a7 100644 --- a/uv.lock +++ b/uv.lock @@ -487,6 +487,8 @@ dependencies = [ { name = "tomlkit" }, { name = "tree-sitter", version = "0.23.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "tree-sitter", version = "0.25.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "tree-sitter-go", version = "0.23.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "tree-sitter-go", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "tree-sitter-groovy" }, { name = "tree-sitter-java" }, { name = "tree-sitter-javascript", version = "0.23.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, @@ -593,6 +595,7 @@ requires-dist = [ { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "tomlkit", specifier = ">=0.11.7" }, { name = "tree-sitter", specifier = ">=0.23.0" }, + { name = "tree-sitter-go", specifier = ">=0.23.0" }, { name = "tree-sitter-groovy", specifier = ">=0.1.0" }, { name = "tree-sitter-java", specifier = ">=0.23.0" }, { name = "tree-sitter-javascript", specifier = ">=0.23.0" }, @@ -5826,6 +5829,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/6e/e64621037357acb83d912276ffd30a859ef117f9c680f2e3cb955f47c680/tree_sitter-0.25.2-cp314-cp314-win_arm64.whl", hash = "sha256:b8d4429954a3beb3e844e2872610d2a4800ba4eb42bb1990c6a4b1949b18459f", size = 117470, upload-time = "2025-09-25T17:37:58.431Z" }, ] +[[package]] +name = "tree-sitter-go" +version = "0.23.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.9.2' and python_full_version < '3.10'", + "python_full_version < '3.9.2'", +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/7f/13b83b877043faadecb5cb70982589ed79e7ebd78f8d239128dc6b23f595/tree_sitter_go-0.23.4.tar.gz", hash = "sha256:0ebff99820657066bec21690623a14c74d9e57a903f95f0837be112ddadf1a52", size = 85686, upload-time = "2024-11-24T19:37:18.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/2d/070137fa47215265459bef90b27902471ddcd61530c3331437bcd9ba93cd/tree_sitter_go-0.23.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c9320f87a05cd47fa0f627b9329bbc09b7ed90de8fe4f5882aed318d6e19962d", size = 45689, upload-time = "2024-11-24T19:37:07.228Z" }, + { url = "https://files.pythonhosted.org/packages/37/8a/9e1dc1c1cefcf060b0105fb294c399ec4808fa1f9e2cbf0463f991b28aed/tree_sitter_go-0.23.4-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:914e63d16b36ab0e4f52b031e574b82d17d0bbfecca138ae83e887a1cf5b71ac", size = 47364, upload-time = "2024-11-24T19:37:08.835Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8a/6c1f26d25cfcedd22d452a299bf9a753d97d5ebd8db4d2047f2002b5b301/tree_sitter_go-0.23.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:330ecbb38d6ea4ef41eba2d473056889705e64f6a51c2fb613de05b1bcb5ba22", size = 66543, upload-time = "2024-11-24T19:37:10.738Z" }, + { url = "https://files.pythonhosted.org/packages/f2/03/d82c4b61db9e29b272aed6742cde37244312e63860048fd66d927bfc4f50/tree_sitter_go-0.23.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd14d23056ae980debfccc0db67d0a168da03792ca2968b1b5dd58ce288084e7", size = 65498, upload-time = "2024-11-24T19:37:12.375Z" }, + { url = "https://files.pythonhosted.org/packages/03/15/c37db75ff873042f74b1eec214fda84dfff985406ccdc94e4d2be9a6888b/tree_sitter_go-0.23.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:c3b40912487fdb78c4028860dd79493a521ffca0104f209849823358db3618a0", size = 64391, upload-time = "2024-11-24T19:37:13.944Z" }, + { url = "https://files.pythonhosted.org/packages/e3/cc/a32de9c9391a859dd5fc938922bb6cd5b7d6114c88998411433e06fe4572/tree_sitter_go-0.23.4-cp39-abi3-win_amd64.whl", hash = "sha256:ae4b231cad2ef76401d33617879cda6321c4d0853f7fd98cb5654c50a218effb", size = 46954, upload-time = "2024-11-24T19:37:14.953Z" }, + { url = "https://files.pythonhosted.org/packages/ec/35/a533173cd846385796eed56dde62eb908b3500e6308fddb4ddc30dc227b8/tree_sitter_go-0.23.4-cp39-abi3-win_arm64.whl", hash = "sha256:2ac907362a3c347145dc1da0858248546500a323de90d2cb76d2a3fdbfc8da25", size = 45276, upload-time = "2024-11-24T19:37:16.623Z" }, +] + +[[package]] +name = "tree-sitter-go" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'emscripten'", + "python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.10.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/01/05/727308adbbc79bcb1c92fc0ea10556a735f9d0f0a5435a18f59d40f7fd77/tree_sitter_go-0.25.0.tar.gz", hash = "sha256:a7466e9b8d94dda94cae8d91629f26edb2d26166fd454d4831c3bf6dfa2e8d68", size = 93890, upload-time = "2025-08-29T06:20:25.044Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/aa/0984707acc2b9bb461fe4a41e7e0fc5b2b1e245c32820f0c83b3c602957c/tree_sitter_go-0.25.0-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b852993063a3429a443e7bd0aa376dd7dd329d595819fabf56ac4cf9d7257b54", size = 47117, upload-time = "2025-08-29T06:20:14.286Z" }, + { url = "https://files.pythonhosted.org/packages/32/16/dd4cb124b35e99239ab3624225da07d4cb8da4d8564ed81d03fcb3a6ba9f/tree_sitter_go-0.25.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:503b81a2b4c31e302869a1de3a352ad0912ccab3df9ac9950197b0a9ceeabd8f", size = 48674, upload-time = "2025-08-29T06:20:17.557Z" }, + { url = "https://files.pythonhosted.org/packages/86/fb/b30d63a08044115d8b8bd196c6c2ab4325fb8db5757249a4ef0563966e2e/tree_sitter_go-0.25.0-cp310-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:04b3b3cb4aff18e74e28d49b716c6f24cb71ddfdd66768987e26e4d0fa812f74", size = 66418, upload-time = "2025-08-29T06:20:18.345Z" }, + { url = "https://files.pythonhosted.org/packages/26/21/d3d88a30ad007419b2c97b3baeeef7431407faf9f686195b6f1cad0aedf9/tree_sitter_go-0.25.0-cp310-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:148255aca2f54b90d48c48a9dbb4c7faad6cad310a980b2c5a5a9822057ed145", size = 72006, upload-time = "2025-08-29T06:20:19.14Z" }, + { url = "https://files.pythonhosted.org/packages/cd/d0/0dd6442353ced8a88bbda9e546f4ea29e381b59b5a40b122e5abb586bb6c/tree_sitter_go-0.25.0-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4d338116cdf8a6c6ff990d2441929b41323ef17c710407abe0993c13417d6aad", size = 70603, upload-time = "2025-08-29T06:20:21.544Z" }, + { url = "https://files.pythonhosted.org/packages/01/e2/ee5e09f63504fc286539535d374d2eaa0e7d489b80f8f744bb3962aff22a/tree_sitter_go-0.25.0-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5608e089d2a29fa8d2b327abeb2ad1cdb8e223c440a6b0ceab0d3fa80bdeebae", size = 66088, upload-time = "2025-08-29T06:20:22.336Z" }, + { url = "https://files.pythonhosted.org/packages/6e/b6/d9142583374720e79aca9ccb394b3795149a54c012e1dfd80738df2d984e/tree_sitter_go-0.25.0-cp310-abi3-win_amd64.whl", hash = "sha256:30d4ada57a223dfc2c32d942f44d284d40f3d1215ddcf108f96807fd36d53022", size = 48152, upload-time = "2025-08-29T06:20:23.089Z" }, + { url = "https://files.pythonhosted.org/packages/9e/00/9a2638e7339236f5b01622952a4d71c1474dd3783d1982a89555fc1f03b1/tree_sitter_go-0.25.0-cp310-abi3-win_arm64.whl", hash = "sha256:d5d62362059bf79997340773d47cc7e7e002883b527a05cca829c46e40b70ded", size = 46752, upload-time = "2025-08-29T06:20:24.235Z" }, +] + [[package]] name = "tree-sitter-groovy" version = "0.1.2" From ba560308bc3f6516e7604e11957150437372b62a Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 23 Apr 2026 14:34:53 +0200 Subject: [PATCH 02/10] code replacer / extraction & base test execution --- codeflash/languages/golang/comparator.py | 102 ++++++++ codeflash/languages/golang/context.py | 127 ++++++++++ codeflash/languages/golang/formatter.py | 111 ++++++++ codeflash/languages/golang/parser.py | 10 +- codeflash/languages/golang/replacement.py | 142 +++++++++++ codeflash/languages/golang/support.py | 96 +++++-- codeflash/languages/golang/test_discovery.py | 97 +++++++ codeflash/languages/golang/test_runner.py | 238 ++++++++++++++++++ .../test_golang/test_comparator.py | 109 ++++++++ .../test_golang/test_context.py | 225 +++++++++++++++++ .../test_golang/test_formatter.py | 100 ++++++++ .../test_golang/test_replacement.py | 200 +++++++++++++++ .../test_golang/test_test_discovery.py | 159 ++++++++++++ .../test_golang/test_test_runner.py | 92 +++++++ 14 files changed, 1785 insertions(+), 23 deletions(-) create mode 100644 codeflash/languages/golang/comparator.py create mode 100644 codeflash/languages/golang/context.py create mode 100644 codeflash/languages/golang/formatter.py create mode 100644 codeflash/languages/golang/replacement.py create mode 100644 codeflash/languages/golang/test_discovery.py create mode 100644 codeflash/languages/golang/test_runner.py create mode 100644 tests/test_languages/test_golang/test_comparator.py create mode 100644 tests/test_languages/test_golang/test_context.py create mode 100644 tests/test_languages/test_golang/test_formatter.py create mode 100644 tests/test_languages/test_golang/test_replacement.py create mode 100644 tests/test_languages/test_golang/test_test_discovery.py create mode 100644 tests/test_languages/test_golang/test_test_runner.py diff --git a/codeflash/languages/golang/comparator.py b/codeflash/languages/golang/comparator.py new file mode 100644 index 000000000..0a3b23850 --- /dev/null +++ b/codeflash/languages/golang/comparator.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass +class TestDiff: + test_name: str + original_passed: bool + candidate_passed: bool + message: str + + +def compare_test_results( + original_results_path: Path, + candidate_results_path: Path, + project_root: Path | None = None, + project_classpath: str | None = None, +) -> tuple[bool, list[TestDiff]]: + original = _load_results(original_results_path) + candidate = _load_results(candidate_results_path) + + diffs: list[TestDiff] = [] + + all_tests = set(original.keys()) | set(candidate.keys()) + + for test_name in sorted(all_tests): + orig = original.get(test_name) + cand = candidate.get(test_name) + + if orig is None: + diffs.append( + TestDiff( + test_name=test_name, + original_passed=False, + candidate_passed=cand or False, + message=f"Test {test_name} only present in candidate results", + ) + ) + continue + + if cand is None: + diffs.append( + TestDiff( + test_name=test_name, + original_passed=orig, + candidate_passed=False, + message=f"Test {test_name} missing from candidate results", + ) + ) + continue + + if orig != cand: + diffs.append( + TestDiff( + test_name=test_name, + original_passed=orig, + candidate_passed=cand, + message=f"Test {test_name}: original {'passed' if orig else 'failed'}, candidate {'passed' if cand else 'failed'}", + ) + ) + + are_equivalent = len(diffs) == 0 + return are_equivalent, diffs + + +def _load_results(path: Path) -> dict[str, bool]: + results: dict[str, bool] = {} + try: + content = path.read_text(encoding="utf-8") + except Exception: + logger.debug("Could not read results file %s", path) + return results + + for line in content.splitlines(): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + except json.JSONDecodeError: + continue + + action = event.get("Action") + test_name = event.get("Test") + if test_name is None: + continue + + if action == "pass": + results[test_name] = True + elif action == "fail": + results[test_name] = False + + return results diff --git a/codeflash/languages/golang/context.py b/codeflash/languages/golang/context.py new file mode 100644 index 000000000..499da4ba1 --- /dev/null +++ b/codeflash/languages/golang/context.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.languages.base import CodeContext, HelperFunction, Language +from codeflash.languages.golang.parser import GoAnalyzer + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +logger = logging.getLogger(__name__) + + +def extract_code_context( + function: FunctionToOptimize, + project_root: Path, + module_root: Path | None = None, + analyzer: GoAnalyzer | None = None, +) -> CodeContext: + analyzer = analyzer or GoAnalyzer() + + try: + source = function.file_path.read_text(encoding="utf-8") + except Exception: + logger.exception("Failed to read %s", function.file_path) + return CodeContext(target_code="", target_file=function.file_path, language=Language.GO) + + receiver_type = _get_receiver_type(function) + target_code = analyzer.extract_function_source(source, function.function_name, receiver_type=receiver_type) + if target_code is None: + target_code = "" + + imports = analyzer.find_imports(source) + import_lines = [_import_to_line(imp) for imp in imports] + + read_only_context = "" + if receiver_type: + read_only_context = _extract_struct_context(source, receiver_type, analyzer) + + helpers = find_helper_functions(source, function, analyzer) + + return CodeContext( + target_code=target_code, + target_file=function.file_path, + helper_functions=helpers, + read_only_context=read_only_context, + imports=import_lines, + language=Language.GO, + ) + + +def find_helper_functions( + source: str, function: FunctionToOptimize, analyzer: GoAnalyzer | None = None +) -> list[HelperFunction]: + analyzer = analyzer or GoAnalyzer() + target_name = function.function_name + + functions = analyzer.find_functions(source) + methods = analyzer.find_methods(source) + + helpers: list[HelperFunction] = [] + + for func in functions: + if func.name == target_name: + continue + if func.name in ("init", "main"): + continue + extracted = analyzer.extract_function_source(source, func.name) + if extracted is None: + continue + helpers.append( + HelperFunction( + name=func.name, + qualified_name=func.name, + file_path=function.file_path, + source_code=extracted, + start_line=func.starting_line, + end_line=func.ending_line, + ) + ) + + receiver_type = _get_receiver_type(function) + for method in methods: + if method.name == target_name and method.receiver_name == receiver_type: + continue + extracted = analyzer.extract_function_source(source, method.name, receiver_type=method.receiver_name) + if extracted is None: + continue + qualified = f"{method.receiver_name}.{method.name}" + helpers.append( + HelperFunction( + name=method.name, + qualified_name=qualified, + file_path=function.file_path, + source_code=extracted, + start_line=method.starting_line, + end_line=method.ending_line, + ) + ) + + return helpers + + +def _get_receiver_type(function: FunctionToOptimize) -> str | None: + if function.parents: + return function.parents[0].name + return None + + +def _import_to_line(imp: object) -> str: + path = getattr(imp, "path", "") + alias = getattr(imp, "alias", None) + if alias: + return f'{alias} "{path}"' + return f'"{path}"' + + +def _extract_struct_context(source: str, struct_name: str, analyzer: GoAnalyzer) -> str: + structs = analyzer.find_structs(source) + for s in structs: + if s.name == struct_name: + lines = source.splitlines() + return "\n".join(lines[s.starting_line - 1 : s.ending_line]) + return "" diff --git a/codeflash/languages/golang/formatter.py b/codeflash/languages/golang/formatter.py new file mode 100644 index 000000000..26dbe8715 --- /dev/null +++ b/codeflash/languages/golang/formatter.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import logging +import shutil +import subprocess +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +def format_go_code(source: str, file_path: Path | None = None) -> str: + gofmt = shutil.which("gofmt") + if gofmt is None: + goimports = shutil.which("goimports") + if goimports is not None: + gofmt = goimports + else: + logger.debug("No Go formatter found (gofmt/goimports), returning source unchanged") + return source + + try: + result = subprocess.run([gofmt], input=source, capture_output=True, text=True, timeout=15, check=False) + if result.returncode == 0: + return result.stdout + logger.debug("gofmt failed: %s", result.stderr) + except subprocess.TimeoutExpired: + logger.warning("gofmt timed out") + except Exception: + logger.debug("gofmt error", exc_info=True) + + return source + + +def normalize_go_code(source: str) -> str: + lines = source.splitlines() + normalized: list[str] = [] + in_block_comment = False + + for line in lines: + if in_block_comment: + if "*/" in line: + in_block_comment = False + line = line[line.index("*/") + 2 :] + else: + continue + + if "//" in line: + comment_pos = _find_line_comment_pos(line) + if comment_pos >= 0: + line = line[:comment_pos] + + if "/*" in line: + start_idx = line.index("/*") + if "*/" in line[start_idx:]: + end_idx = line.index("*/", start_idx) + line = line[:start_idx] + line[end_idx + 2 :] + else: + in_block_comment = True + line = line[:start_idx] + + stripped = line.strip() + if stripped: + normalized.append(stripped) + + return "\n".join(normalized) + + +def _find_line_comment_pos(line: str) -> int: + in_string = False + in_rune = False + escape_next = False + in_raw_string = False + + i = 0 + while i < len(line): + char = line[i] + + if escape_next: + escape_next = False + i += 1 + continue + + if in_raw_string: + if char == "`": + in_raw_string = False + i += 1 + continue + + if char == "`": + in_raw_string = True + i += 1 + continue + + if char == "\\": + escape_next = True + i += 1 + continue + + if char == '"' and not in_rune: + in_string = not in_string + elif char == "'" and not in_string: + in_rune = not in_rune + elif not in_string and not in_rune and i < len(line) - 1 and line[i : i + 2] == "//": + return i + + i += 1 + + return -1 diff --git a/codeflash/languages/golang/parser.py b/codeflash/languages/golang/parser.py index ce0f173e6..05f378ad7 100644 --- a/codeflash/languages/golang/parser.py +++ b/codeflash/languages/golang/parser.py @@ -87,6 +87,10 @@ def __init__(self) -> None: self._source_bytes: bytes | None = None self._tree: Tree | None = None + @property + def last_tree(self) -> Tree | None: + return self._tree + def parse(self, source: str) -> Tree: self._source_bytes = source.encode("utf-8") self._tree = self._parser.parse(self._source_bytes) @@ -221,7 +225,7 @@ def _parse_method_node(self, node: Node) -> GoMethodNode | None: receiver_node = node.child_by_field_name("receiver") if receiver_node is None: return None - receiver_name, receiver_is_pointer = self._parse_receiver(receiver_node) + receiver_name, receiver_is_pointer = self.parse_receiver(receiver_node) if receiver_name is None: return None @@ -240,7 +244,7 @@ def _parse_method_node(self, node: Node) -> GoMethodNode | None: doc_start_line=doc_line, ) - def _parse_receiver(self, receiver_node: Node) -> tuple[str | None, bool]: + def parse_receiver(self, receiver_node: Node) -> tuple[str | None, bool]: for param in _children_of_type(receiver_node, "parameter_declaration"): type_node = param.child_by_field_name("type") if type_node is None: @@ -284,7 +288,7 @@ def extract_function_source(self, source: str, func_name: str, receiver_type: st continue recv_node = node.child_by_field_name("receiver") if recv_node is not None: - recv_name, _ = self._parse_receiver(recv_node) + recv_name, _ = self.parse_receiver(recv_node) if recv_name == receiver_type: return self._get_source_with_doc(node) return None diff --git a/codeflash/languages/golang/replacement.py b/codeflash/languages/golang/replacement.py new file mode 100644 index 000000000..63dc1dcc4 --- /dev/null +++ b/codeflash/languages/golang/replacement.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from codeflash.languages.golang.parser import GoAnalyzer + +if TYPE_CHECKING: + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +logger = logging.getLogger(__name__) + + +def replace_function( + source: str, function: FunctionToOptimize, new_source: str, analyzer: GoAnalyzer | None = None +) -> str: + analyzer = analyzer or GoAnalyzer() + receiver_type = function.parents[0].name if function.parents else None + + tree = analyzer.parse(source) + target_node = None + + for node in tree.root_node.children: + if receiver_type is None and node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and analyzer.get_node_text(name_node) == function.function_name: + target_node = node + break + elif receiver_type is not None and node.type == "method_declaration": + name_node = node.child_by_field_name("name") + if name_node is None or analyzer.get_node_text(name_node) != function.function_name: + continue + recv_node = node.child_by_field_name("receiver") + if recv_node is not None: + recv_name, _ = analyzer.parse_receiver(recv_node) + if recv_name == receiver_type: + target_node = node + break + + if target_node is None: + logger.warning("Could not find function %s in source for replacement", function.function_name) + return source + + lines = source.splitlines(keepends=True) + doc_line = _find_doc_comment_start(target_node) + start_line = (doc_line if doc_line is not None else target_node.start_point.row + 1) - 1 + end_line = target_node.end_point.row + 1 + + new_source_stripped = new_source.rstrip("\n") + "\n" + + result_lines = [*lines[:start_line], new_source_stripped, *lines[end_line:]] + return "".join(result_lines) + + +def add_global_declarations(optimized_code: str, original_source: str, analyzer: GoAnalyzer | None = None) -> str: + analyzer = analyzer or GoAnalyzer() + + opt_imports = analyzer.find_imports(optimized_code) + orig_imports = analyzer.find_imports(original_source) + orig_paths = {imp.path for imp in orig_imports} + + new_imports = [imp for imp in opt_imports if imp.path not in orig_paths] + if not new_imports: + return original_source + + lines = original_source.splitlines(keepends=True) + + import_block_end = _find_import_block_end(original_source, analyzer) + + new_import_lines = [] + for imp in new_imports: + if imp.alias: + new_import_lines.append(f'\t{imp.alias} "{imp.path}"\n') + else: + new_import_lines.append(f'\t"{imp.path}"\n') + + if orig_imports: + last_import = max(orig_imports, key=lambda i: i.ending_line) + insert_at = last_import.ending_line + for node in analyzer.last_tree.root_node.children if analyzer.last_tree else []: + if node.type == "import_declaration": + for child in node.children: + if child.type == "import_spec_list": + close_paren_line = child.end_point.row + insert_at = close_paren_line + break + return "".join([*lines[:insert_at], *new_import_lines, *lines[insert_at:]]) + + insert_at = import_block_end + import_block = "import (\n" + "".join(new_import_lines) + ")\n\n" + return "".join([*lines[:insert_at], import_block, *lines[insert_at:]]) + + +def remove_test_functions(test_source: str, functions_to_remove: list[str], analyzer: GoAnalyzer | None = None) -> str: + analyzer = analyzer or GoAnalyzer() + tree = analyzer.parse(test_source) + lines = test_source.splitlines(keepends=True) + + regions_to_remove: list[tuple[int, int]] = [] + + for node in tree.root_node.children: + if node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and analyzer.get_node_text(name_node) in functions_to_remove: + doc_start = _find_doc_comment_start(node) + start = (doc_start if doc_start is not None else node.start_point.row + 1) - 1 + end = node.end_point.row + 1 + regions_to_remove.append((start, end)) + + for start, end in reversed(regions_to_remove): + del lines[start:end] + + return "".join(lines) + + +def _find_doc_comment_start(node: object) -> int | None: + prev = getattr(node, "prev_named_sibling", None) + if prev is None: + return None + if getattr(prev, "type", None) != "comment": + return None + if prev.end_point.row + 1 != node.start_point.row: + return None + comment_start = prev.start_point.row + 1 + current = prev + while True: + earlier = getattr(current, "prev_named_sibling", None) + if earlier is None or getattr(earlier, "type", None) != "comment": + break + if earlier.end_point.row + 1 != current.start_point.row: + break + comment_start = earlier.start_point.row + 1 + current = earlier + return comment_start + + +def _find_import_block_end(source: str, analyzer: GoAnalyzer) -> int: + tree = analyzer.parse(source) + for node in tree.root_node.children: + if node.type == "package_clause": + return node.end_point.row + 1 + return 0 diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index 4f5d08b3f..4aacf1768 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -4,9 +4,20 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from codeflash.languages.golang.comparator import compare_test_results as _compare_results from codeflash.languages.golang.config import detect_go_project, detect_go_version +from codeflash.languages.golang.context import extract_code_context as _extract_context +from codeflash.languages.golang.context import find_helper_functions as _find_helpers from codeflash.languages.golang.discovery import discover_functions_from_source +from codeflash.languages.golang.formatter import format_go_code, normalize_go_code from codeflash.languages.golang.parser import GoAnalyzer +from codeflash.languages.golang.replacement import add_global_declarations as _add_globals +from codeflash.languages.golang.replacement import remove_test_functions as _remove_tests +from codeflash.languages.golang.replacement import replace_function as _replace_func +from codeflash.languages.golang.test_discovery import discover_tests as _discover_tests +from codeflash.languages.golang.test_runner import parse_test_results as _parse_results +from codeflash.languages.golang.test_runner import run_behavioral_tests as _run_behavioral +from codeflash.languages.golang.test_runner import run_benchmarking_tests as _run_benchmarking from codeflash.languages.language_enum import Language from codeflash.languages.registry import register_language @@ -19,6 +30,7 @@ FunctionFilterCriteria, HelperFunction, ReferenceInfo, + TestInfo, ) from codeflash.models.function_types import FunctionToOptimize @@ -80,39 +92,47 @@ def discover_functions( ) -> list[FunctionToOptimize]: return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer) - def discover_tests(self, test_root: Path, source_functions: Sequence[FunctionToOptimize]) -> dict[str, list[Any]]: - raise NotImplementedError + def discover_tests( + self, test_root: Path, source_functions: Sequence[FunctionToOptimize] + ) -> dict[str, list[TestInfo]]: + return _discover_tests(test_root, source_functions) def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: return self._analyzer.validate_syntax(source) def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: - raise NotImplementedError + return _extract_context(function, project_root, module_root, self._analyzer) def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: - raise NotImplementedError + try: + source = function.file_path.read_text(encoding="utf-8") + except Exception: + return [] + return _find_helpers(source, function, self._analyzer) def find_references( self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 100 ) -> list[ReferenceInfo]: - raise NotImplementedError + return [] def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: - raise NotImplementedError + return _replace_func(source, function, new_source, self._analyzer) def format_code(self, source: str, file_path: Path | None = None) -> str: - raise NotImplementedError + return format_go_code(source, file_path) def normalize_code(self, source: str) -> str: - raise NotImplementedError + return normalize_go_code(source) def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: - raise NotImplementedError + return _add_globals(optimized_code, original_source, self._analyzer) def prepare_module( self, module_code: str, module_path: Path, project_root: Path ) -> tuple[dict[Path, Any], None] | None: - raise NotImplementedError + if not self._analyzer.validate_syntax(module_code): + return None + return {module_path: module_code}, None def setup_test_config(self, test_cfg: Any) -> None: project_root = getattr(test_cfg, "project_root_path", Path.cwd()) @@ -124,17 +144,53 @@ def setup_test_config(self, test_cfg: Any) -> None: def detect_module_system(self, project_root: Path, source_file: Path | None = None) -> str | None: return None - def run_behavioral_tests(self, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError - - def run_benchmarking_tests(self, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError + def run_behavioral_tests( + self, + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + enable_coverage: bool = False, + candidate_index: int = 0, + ) -> tuple[Path, Any, Path | None, Path | None]: + return _run_behavioral(test_paths, test_env, cwd, timeout, project_root, enable_coverage, candidate_index) + + def run_benchmarking_tests( + self, + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + min_loops: int = 5, + max_loops: int = 100_000, + target_duration_seconds: float = 10.0, + inner_iterations: int = 100, + ) -> tuple[Path, Any]: + return _run_benchmarking( + test_paths, + test_env, + cwd, + timeout, + project_root, + min_loops, + max_loops, + target_duration_seconds, + inner_iterations, + ) def run_line_profile_tests(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError - def compare_test_results(self, *args: Any, **kwargs: Any) -> tuple[bool, list[Any]]: - raise NotImplementedError + def compare_test_results( + self, + original_results_path: Path, + candidate_results_path: Path, + project_root: Path | None = None, + project_classpath: str | None = None, + ) -> tuple[bool, list[Any]]: + return _compare_results(original_results_path, candidate_results_path, project_root, project_classpath) def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str: return source @@ -178,12 +234,12 @@ def add_runtime_comments( return test_source def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str: - raise NotImplementedError + return _remove_tests(test_source, functions_to_remove, self._analyzer) def get_test_dir_for_source(self, test_dir: Path, source_file: Path | None = None) -> Path | None: if source_file is not None: return source_file.parent return test_dir - def parse_test_results(self, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError + def parse_test_results(self, json_output_path: Path, stdout: str) -> Any: + return _parse_results(json_output_path, stdout) diff --git a/codeflash/languages/golang/test_discovery.py b/codeflash/languages/golang/test_discovery.py new file mode 100644 index 000000000..d7ec039fc --- /dev/null +++ b/codeflash/languages/golang/test_discovery.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING + +from codeflash.languages.base import TestInfo + +if TYPE_CHECKING: + from collections.abc import Sequence + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + +logger = logging.getLogger(__name__) + +GO_TEST_FUNC_RE = re.compile(r"^func\s+(Test\w+)\s*\(", re.MULTILINE) + + +def discover_tests(test_root: Path, source_functions: Sequence[FunctionToOptimize]) -> dict[str, list[TestInfo]]: + func_name_to_qn: dict[str, list[str]] = {} + for func in source_functions: + func_name_to_qn.setdefault(func.function_name, []).append(func.qualified_name) + + test_files = list(test_root.rglob("*_test.go")) + result: dict[str, list[TestInfo]] = {} + + for test_file in test_files: + try: + content = test_file.read_text(encoding="utf-8") + except Exception: + logger.debug("Could not read test file %s", test_file) + continue + + test_func_names = GO_TEST_FUNC_RE.findall(content) + for test_func_name in test_func_names: + matched_qns = _match_test_to_functions(test_func_name, content, func_name_to_qn) + for qn in matched_qns: + info = TestInfo(test_name=test_func_name, test_file=test_file) + result.setdefault(qn, []).append(info) + + return result + + +def _match_test_to_functions(test_func_name: str, test_source: str, func_name_to_qn: dict[str, list[str]]) -> list[str]: + matched: list[str] = [] + + target_name = _extract_target_name(test_func_name) + if target_name and target_name in func_name_to_qn: + matched.extend(func_name_to_qn[target_name]) + return matched + + for func_name, qns in func_name_to_qn.items(): + if _test_calls_function(test_source, test_func_name, func_name): + matched.extend(qns) + + return matched + + +def _extract_target_name(test_func_name: str) -> str | None: + if not test_func_name.startswith("Test"): + return None + remainder = test_func_name[4:] + if not remainder: + return None + name = remainder.split("_")[0] + if not name: + return None + return name + + +def _test_calls_function(test_source: str, test_func_name: str, func_name: str) -> bool: + func_body = _extract_test_body(test_source, test_func_name) + if func_body is None: + return False + call_pattern = re.compile(rf"\b{re.escape(func_name)}\s*\(") + return call_pattern.search(func_body) is not None + + +def _extract_test_body(test_source: str, test_func_name: str) -> str | None: + pattern = re.compile(rf"func\s+{re.escape(test_func_name)}\s*\([^)]*\)\s*\{{") + match = pattern.search(test_source) + if match is None: + return None + + start = match.end() + depth = 1 + pos = start + while pos < len(test_source) and depth > 0: + ch = test_source[pos] + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + pos += 1 + + return test_source[start : pos - 1] if depth == 0 else None diff --git a/codeflash/languages/golang/test_runner.py b/codeflash/languages/golang/test_runner.py new file mode 100644 index 000000000..0d253655d --- /dev/null +++ b/codeflash/languages/golang/test_runner.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import contextlib +import json +import logging +import os +import signal +import subprocess +import sys +from typing import TYPE_CHECKING, Any + +from codeflash.languages.base import TestResult + +if TYPE_CHECKING: + from pathlib import Path + +logger = logging.getLogger(__name__) + + +def run_behavioral_tests( + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + enable_coverage: bool = False, + candidate_index: int = 0, +) -> tuple[Path, subprocess.CompletedProcess[str], Path | None, Path | None]: + result_dir = cwd / ".codeflash" / "go_test_results" + result_dir.mkdir(parents=True, exist_ok=True) + json_output_file = result_dir / f"behavioral_{candidate_index}.jsonl" + + test_file_paths = _collect_test_file_paths(test_paths) + packages = _test_files_to_packages(test_file_paths, cwd) + if not packages: + packages = ["./..."] + + env = {**os.environ, **test_env} + + cmd = ["go", "test", "-json", "-v", "-count=1", *packages] + + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) + + json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") + + return json_output_file, proc_result, None, None + + +def run_benchmarking_tests( + test_paths: Any, + test_env: dict[str, str], + cwd: Path, + timeout: int | None = None, + project_root: Path | None = None, + min_loops: int = 5, + max_loops: int = 100_000, + target_duration_seconds: float = 10.0, + inner_iterations: int = 100, +) -> tuple[Path, subprocess.CompletedProcess[str]]: + result_dir = cwd / ".codeflash" / "go_test_results" + result_dir.mkdir(parents=True, exist_ok=True) + json_output_file = result_dir / "benchmark.jsonl" + + test_file_paths = _collect_test_file_paths(test_paths) + packages = _test_files_to_packages(test_file_paths, cwd) + if not packages: + packages = ["./..."] + + env = {**os.environ, **test_env} + + benchtime = f"{target_duration_seconds:.0f}s" + cmd = [ + "go", + "test", + "-json", + "-v", + "-bench=.", + f"-benchtime={benchtime}", + "-benchmem", + f"-count={min_loops}", + "-run=^$", + *packages, + ] + + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) + + json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") + + return json_output_file, proc_result + + +def parse_go_test_json(json_output: str) -> list[TestResult]: + results: dict[str, TestResult] = {} + + for line in json_output.splitlines(): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + except json.JSONDecodeError: + continue + + action = event.get("Action") + test_name = event.get("Test") + if test_name is None: + continue + + package = event.get("Package", "") + + if action == "pass": + elapsed = event.get("Elapsed", 0) + results[test_name] = TestResult( + test_name=test_name, + test_file=_package_to_path(package), + passed=True, + runtime_ns=int(elapsed * 1_000_000_000) if elapsed else None, + ) + elif action == "fail": + elapsed = event.get("Elapsed", 0) + existing = results.get(test_name) + stdout = existing.stdout if existing else "" + results[test_name] = TestResult( + test_name=test_name, + test_file=_package_to_path(package), + passed=False, + runtime_ns=int(elapsed * 1_000_000_000) if elapsed else None, + stdout=stdout, + error_message=f"Test {test_name} failed", + ) + elif action == "output": + output_text = event.get("Output", "") + if test_name in results: + results[test_name] = TestResult( + test_name=results[test_name].test_name, + test_file=results[test_name].test_file, + passed=results[test_name].passed, + runtime_ns=results[test_name].runtime_ns, + stdout=results[test_name].stdout + output_text, + stderr=results[test_name].stderr, + error_message=results[test_name].error_message, + ) + else: + results[test_name] = TestResult( + test_name=test_name, test_file=_package_to_path(package), passed=True, stdout=output_text + ) + + return list(results.values()) + + +def parse_test_results(json_output_path: Path, stdout: str) -> list[TestResult]: + try: + content = json_output_path.read_text(encoding="utf-8") + except Exception: + content = stdout + return parse_go_test_json(content) + + +def _package_to_path(package: str) -> Path: + from pathlib import Path as _Path + + if package: + return _Path(package.replace("/", os.sep)) + return _Path() + + +def _collect_test_file_paths(test_paths: Any) -> list[Path]: + from pathlib import Path as _Path + + if test_paths is None: + return [] + + if hasattr(test_paths, "test_files"): + paths = [] + for tf in test_paths.test_files: + p = getattr(tf, "instrumented_behavior_file_path", None) or getattr(tf, "original_file_path", None) + if p is not None: + paths.append(_Path(p)) + return paths + + if isinstance(test_paths, list): + return [_Path(p) for p in test_paths] + + return [] + + +def _test_files_to_packages(test_files: list[Path], cwd: Path) -> list[str]: + dirs: set[str] = set() + for f in test_files: + try: + rel = f.resolve().parent.relative_to(cwd.resolve()) + dirs.add(f"./{rel.as_posix()}") + except ValueError: + continue + return sorted(dirs) if dirs else [] + + +def _run_cmd_kill_pg_on_timeout( + cmd: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None, timeout: int | None = None +) -> subprocess.CompletedProcess[str]: + if sys.platform == "win32": + try: + return subprocess.run(cmd, cwd=cwd, env=env, capture_output=True, text=True, timeout=timeout, check=False) + except subprocess.TimeoutExpired: + return subprocess.CompletedProcess( + args=cmd, returncode=-2, stdout="", stderr=f"Process timed out after {timeout}s" + ) + + proc = subprocess.Popen( + cmd, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, start_new_session=True + ) + try: + stdout, stderr = proc.communicate(timeout=timeout) + return subprocess.CompletedProcess(args=cmd, returncode=proc.returncode, stdout=stdout, stderr=stderr) + except subprocess.TimeoutExpired: + pgid = None + try: + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) + except (ProcessLookupError, OSError): + proc.kill() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + if pgid is not None: + with contextlib.suppress(ProcessLookupError, OSError): + os.killpg(pgid, signal.SIGKILL) + else: + proc.kill() + proc.wait() + try: + stdout_data = proc.stdout.read() if proc.stdout else "" + stderr_data = proc.stderr.read() if proc.stderr else "" + except Exception: + stdout_data, stderr_data = "", "" + return subprocess.CompletedProcess( + args=cmd, returncode=-2, stdout=stdout_data, stderr=stderr_data or f"Process timed out after {timeout}s" + ) diff --git a/tests/test_languages/test_golang/test_comparator.py b/tests/test_languages/test_golang/test_comparator.py new file mode 100644 index 000000000..43bb4631f --- /dev/null +++ b/tests/test_languages/test_golang/test_comparator.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.comparator import compare_test_results + + +class TestCompareTestResults: + def test_equivalent_results(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] + + def test_candidate_fails_one(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"fail","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is False + assert len(diffs) == 1 + assert diffs[0].test_name == "TestSub" + assert diffs[0].original_passed is True + assert diffs[0].candidate_passed is False + + def test_missing_test_in_candidate(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestSub","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is False + assert len(diffs) == 1 + assert diffs[0].test_name == "TestSub" + + def test_extra_test_in_candidate(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"pass","Test":"TestAdd","Package":"calc"}\n' + '{"Action":"pass","Test":"TestNew","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is False + assert len(diffs) == 1 + assert diffs[0].test_name == "TestNew" + + def test_both_empty(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text("", encoding="utf-8") + cand.write_text("", encoding="utf-8") + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] + + def test_missing_files(self, tmp_path: Path) -> None: + orig = (tmp_path / "missing1.jsonl").resolve() + cand = (tmp_path / "missing2.jsonl").resolve() + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] + + def test_both_fail_same_test(self, tmp_path: Path) -> None: + orig = (tmp_path / "original.jsonl").resolve() + cand = (tmp_path / "candidate.jsonl").resolve() + orig.write_text( + '{"Action":"fail","Test":"TestBroken","Package":"calc"}\n', + encoding="utf-8", + ) + cand.write_text( + '{"Action":"fail","Test":"TestBroken","Package":"calc"}\n', + encoding="utf-8", + ) + eq, diffs = compare_test_results(orig, cand) + assert eq is True + assert diffs == [] diff --git a/tests/test_languages/test_golang/test_context.py b/tests/test_languages/test_golang/test_context.py new file mode 100644 index 000000000..a21844d7a --- /dev/null +++ b/tests/test_languages/test_golang/test_context.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.base import Language +from codeflash.languages.golang.context import extract_code_context, find_helper_functions +from codeflash.models.function_types import FunctionParent, FunctionToOptimize + +GO_SOURCE_WITH_METHOD = """\ +package calc + +import "math" + +type Calculator struct { +\tResult float64 +} + +// Add returns the sum. +func Add(a, b int) int { +\treturn a + b +} + +func subtract(a, b int) int { +\treturn a - b +} + +func (c *Calculator) AddFloat(val float64) float64 { +\tc.Result += val +\treturn c.Result +} +""" + + +class TestExtractCodeContextFunction: + def test_target_code_with_doc_comment(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "// Add returns the sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + + def test_target_code_no_doc(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="subtract", file_path=source_file, language="go", starting_line=14, ending_line=16 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "func subtract(a, b int) int {\n\treturn a - b\n}" + + def test_imports_extracted(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.imports == ['"math"'] + + def test_no_read_only_context_for_function(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.read_only_context == "" + + def test_helpers_for_function(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + helper_names = [h.name for h in ctx.helper_functions] + assert "subtract" in helper_names + assert "AddFloat" in helper_names + assert "Add" not in helper_names + + def test_language_is_go(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.language == Language.GO + + def test_target_file_path(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_file == source_file + + +class TestExtractCodeContextMethod: + def test_method_target_code(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == ( + "func (c *Calculator) AddFloat(val float64) float64 {\n" + "\tc.Result += val\n" + "\treturn c.Result\n" + "}" + ) + + def test_method_read_only_context_is_struct(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.read_only_context == "type Calculator struct {\n\tResult float64\n}" + + def test_method_helpers_exclude_self(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + helper_names = [h.name for h in ctx.helper_functions] + assert "Add" in helper_names + assert "subtract" in helper_names + assert "AddFloat" not in helper_names + + def test_method_helper_qualified_names(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize( + function_name="AddFloat", + file_path=source_file, + parents=[FunctionParent(name="Calculator", type="StructDef")], + language="go", + is_method=True, + starting_line=18, + ending_line=21, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + helper_qns = [h.qualified_name for h in ctx.helper_functions] + assert "Add" in helper_qns + assert "subtract" in helper_qns + + +class TestExtractCodeContextEdgeCases: + def test_missing_file(self, tmp_path: Path) -> None: + missing = (tmp_path / "missing.go").resolve() + func = FunctionToOptimize(function_name="Foo", file_path=missing, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "" + assert ctx.language == Language.GO + + def test_function_not_in_source(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text("package calc\n\nfunc Other() {}\n", encoding="utf-8") + func = FunctionToOptimize(function_name="Missing", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.target_code == "" + + def test_multi_import(self, tmp_path: Path) -> None: + source = 'package calc\n\nimport (\n\t"fmt"\n\t"os"\n\tstr "strings"\n)\n\nfunc Hello() string {\n\treturn "hi"\n}\n' + source_file = (tmp_path / "hello.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize(function_name="Hello", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.imports == ['"fmt"', '"os"', 'str "strings"'] + + +class TestFindHelperFunctions: + def test_skips_init_and_main(self, tmp_path: Path) -> None: + source = "package main\n\nfunc init() { println() }\n\nfunc main() { println() }\n\nfunc Target() int { return 1 }\n" + source_file = (tmp_path / "main.go").resolve() + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + helpers = find_helper_functions(source, func) + helper_names = [h.name for h in helpers] + assert "init" not in helper_names + assert "main" not in helper_names + + def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "type Calc struct{}\n\n" + "func (c Calc) Target() int { return 1 }\n\n" + "func (c Calc) Helper() int { return 2 }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + func = FunctionToOptimize( + function_name="Target", + file_path=source_file, + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + helpers = find_helper_functions(source, func) + assert len(helpers) == 1 + assert helpers[0].qualified_name == "Calc.Helper" diff --git a/tests/test_languages/test_golang/test_formatter.py b/tests/test_languages/test_golang/test_formatter.py new file mode 100644 index 000000000..2c36f44b6 --- /dev/null +++ b/tests/test_languages/test_golang/test_formatter.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from unittest.mock import patch + +from codeflash.languages.golang.formatter import format_go_code, normalize_go_code + + +class TestNormalizeGoCode: + def test_strips_line_comments(self) -> None: + source = "package calc\n\n// Add returns the sum.\nfunc Add(a, b int) int {\n\treturn a + b // fast path\n}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}" + assert result == expected + + def test_strips_single_line_block_comment(self) -> None: + source = "package calc\n\n/* block comment */\nfunc Subtract(a, b int) int {\n\treturn a - b\n}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Subtract(a, b int) int {\nreturn a - b\n}" + assert result == expected + + def test_strips_multi_line_block_comment(self) -> None: + source = "package calc\n\n/*\nThis is a\nmulti-line comment.\n*/\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}" + assert result == expected + + def test_preserves_comment_in_string(self) -> None: + source = 'func Greet() string {\n\treturn "hello // world"\n}\n' + result = normalize_go_code(source) + expected = 'func Greet() string {\nreturn "hello // world"\n}' + assert result == expected + + def test_preserves_comment_in_raw_string(self) -> None: + source = "func Greet() string {\n\treturn `hello // world`\n}\n" + result = normalize_go_code(source) + expected = "func Greet() string {\nreturn `hello // world`\n}" + assert result == expected + + def test_strips_whitespace_and_empty_lines(self) -> None: + source = "package calc\n\n\n\nfunc Add(a, b int) int {\n\t\treturn a + b\n\t}\n" + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}" + assert result == expected + + def test_mixed_comments(self) -> None: + source = ( + "package calc\n\n" + "// Add returns the sum.\n" + "func Add(a, b int) int {\n" + "\treturn a + b // fast path\n" + "}\n\n" + "/* block comment */\n" + "func Subtract(a, b int) int {\n" + "\treturn a - b\n" + "}\n" + ) + result = normalize_go_code(source) + expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}\nfunc Subtract(a, b int) int {\nreturn a - b\n}" + assert result == expected + + def test_inline_block_comment(self) -> None: + source = "func Add(a /* first */, b int) int {\n\treturn a + b\n}\n" + result = normalize_go_code(source) + expected = "func Add(a , b int) int {\nreturn a + b\n}" + assert result == expected + + def test_empty_input(self) -> None: + assert normalize_go_code("") == "" + + def test_only_comments(self) -> None: + source = "// just a comment\n// another line\n" + result = normalize_go_code(source) + assert result == "" + + +class TestFormatGoCode: + def test_no_formatter_returns_source(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\nreturn a+b\n}\n" + with patch("codeflash.languages.golang.formatter.shutil.which", return_value=None): + result = format_go_code(source) + assert result == source + + def test_format_with_gofmt(self) -> None: + import shutil + + if shutil.which("gofmt") is None: + return + source = "package calc\n\nfunc Add(a,b int)int{\nreturn a+b\n}\n" + result = format_go_code(source) + assert result != source + assert "func Add" in result + + def test_format_failure_returns_source(self) -> None: + source = "this is not valid go" + with patch("codeflash.languages.golang.formatter.shutil.which", return_value="/usr/bin/gofmt"): + with patch("codeflash.languages.golang.formatter.subprocess.run") as mock_run: + mock_run.return_value.returncode = 2 + mock_run.return_value.stderr = "syntax error" + result = format_go_code(source) + assert result == source diff --git a/tests/test_languages/test_golang/test_replacement.py b/tests/test_languages/test_golang/test_replacement.py new file mode 100644 index 000000000..c36f3a012 --- /dev/null +++ b/tests/test_languages/test_golang/test_replacement.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.replacement import add_global_declarations, remove_test_functions, replace_function +from codeflash.models.function_types import FunctionParent, FunctionToOptimize + + +class TestReplaceFunction: + def test_replace_basic_function(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n\nfunc Subtract(a, b int) int {\n\treturn a - b\n}\n" + func = FunctionToOptimize(function_name="Add", file_path=Path("/project/calc.go"), language="go") + new_body = "func Add(a, b int) int {\n\tresult := a + b\n\treturn result\n}" + result = replace_function(source, func, new_body) + expected = "package calc\n\nfunc Add(a, b int) int {\n\tresult := a + b\n\treturn result\n}\n\nfunc Subtract(a, b int) int {\n\treturn a - b\n}\n" + assert result == expected + + def test_replace_function_with_doc_comment(self) -> None: + source = "package calc\n\n// Add returns the sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + func = FunctionToOptimize(function_name="Add", file_path=Path("/project/calc.go"), language="go") + new_body = "// Add returns an optimized sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}" + result = replace_function(source, func, new_body) + expected = "package calc\n\n// Add returns an optimized sum.\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + assert result == expected + + def test_replace_method(self) -> None: + source = ( + "package calc\n\n" + "type Calc struct {\n\tResult float64\n}\n\n" + "// AddFloat adds a value.\n" + "func (c *Calc) AddFloat(val float64) float64 {\n\tc.Result += val\n\treturn c.Result\n}\n\n" + "func (c Calc) GetResult() float64 {\n\treturn c.Result\n}\n" + ) + func = FunctionToOptimize( + function_name="AddFloat", + file_path=Path("/project/calc.go"), + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + new_body = "// AddFloat adds a value (optimized).\nfunc (c *Calc) AddFloat(val float64) float64 {\n\tc.Result = c.Result + val\n\treturn c.Result\n}" + result = replace_function(source, func, new_body) + expected = ( + "package calc\n\n" + "type Calc struct {\n\tResult float64\n}\n\n" + "// AddFloat adds a value (optimized).\n" + "func (c *Calc) AddFloat(val float64) float64 {\n\tc.Result = c.Result + val\n\treturn c.Result\n}\n\n" + "func (c Calc) GetResult() float64 {\n\treturn c.Result\n}\n" + ) + assert result == expected + + def test_replace_nonexistent_returns_original(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + func = FunctionToOptimize(function_name="Missing", file_path=Path("/project/calc.go"), language="go") + result = replace_function(source, func, "func Missing() {}") + assert result == source + + def test_replace_preserves_surrounding_code(self) -> None: + source = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n\treturn a + b\n}\n\n" + "func Subtract(a, b int) int {\n\treturn a - b\n}\n" + ) + func = FunctionToOptimize(function_name="Add", file_path=Path("/project/calc.go"), language="go") + new_body = "func Add(a, b int) int {\n\treturn b + a\n}" + result = replace_function(source, func, new_body) + expected = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n\treturn b + a\n}\n\n" + "func Subtract(a, b int) int {\n\treturn a - b\n}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarations: + def test_add_import_to_existing_block(self) -> None: + original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + optimized = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + result = add_global_declarations(optimized, original) + expected = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + assert result == expected + + def test_add_aliased_import(self) -> None: + original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + optimized = 'package calc\n\nimport (\n\t"fmt"\n\tstr "strings"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + result = add_global_declarations(optimized, original) + expected = 'package calc\n\nimport (\n\t"fmt"\n\tstr "strings"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + assert result == expected + + def test_add_import_when_no_existing_imports(self) -> None: + original = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + optimized = 'package calc\n\nimport "math"\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + result = add_global_declarations(optimized, original) + expected = 'package calc\nimport (\n\t"math"\n)\n\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' + assert result == expected + + def test_no_new_imports_returns_unchanged(self) -> None: + source = "package calc\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n" + result = add_global_declarations(source, source) + assert result == source + + +class TestRemoveTestFunctions: + def test_remove_single_function(self) -> None: + test_source = ( + "package calc\n\n" + 'import "testing"\n\n' + "func TestAdd(t *testing.T) {\n" + "\tresult := Add(2, 3)\n" + "\tif result != 5 {\n" + '\t\tt.Errorf("want 5, got %d", result)\n' + "\t}\n" + "}\n\n" + "func TestSubtract(t *testing.T) {\n" + "\tresult := Subtract(5, 3)\n" + "\tif result != 2 {\n" + '\t\tt.Errorf("want 2, got %d", result)\n' + "\t}\n" + "}\n" + ) + result = remove_test_functions(test_source, ["TestAdd"]) + expected = ( + "package calc\n\n" + 'import "testing"\n\n\n' + "func TestSubtract(t *testing.T) {\n" + "\tresult := Subtract(5, 3)\n" + "\tif result != 2 {\n" + '\t\tt.Errorf("want 2, got %d", result)\n' + "\t}\n" + "}\n" + ) + assert result == expected + + def test_remove_multiple_functions(self) -> None: + test_source = ( + "package calc\n\n" + 'import "testing"\n\n' + "// TestAdd tests addition.\n" + "func TestAdd(t *testing.T) {\n" + "\tif Add(1, 2) != 3 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + "func TestMultiply(t *testing.T) {\n" + "\tif Multiply(2, 3) != 6 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n" + ) + result = remove_test_functions(test_source, ["TestAdd", "TestMultiply"]) + expected = ( + "package calc\n\n" + 'import "testing"\n\n\n' + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + ) + assert result == expected + + def test_remove_function_with_doc_comment(self) -> None: + test_source = ( + "package calc\n\n" + 'import "testing"\n\n' + "// TestAdd tests addition.\n" + "func TestAdd(t *testing.T) {\n" + "\tif Add(1, 2) != 3 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n\n" + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n" + ) + result = remove_test_functions(test_source, ["TestAdd"]) + expected = ( + "package calc\n\n" + 'import "testing"\n\n\n' + "func TestSubtract(t *testing.T) {\n" + "\tif Subtract(5, 3) != 2 {\n" + "\t\tt.Fail()\n" + "\t}\n" + "}\n" + ) + assert result == expected + + def test_remove_none_returns_unchanged(self) -> None: + test_source = "package calc\n\nfunc TestAdd(t *testing.T) {\n\tt.Log(\"ok\")\n}\n" + result = remove_test_functions(test_source, []) + assert result == test_source diff --git a/tests/test_languages/test_golang/test_test_discovery.py b/tests/test_languages/test_golang/test_test_discovery.py new file mode 100644 index 000000000..c31c8d006 --- /dev/null +++ b/tests/test_languages/test_golang/test_test_discovery.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.test_discovery import ( + _extract_target_name, + _extract_test_body, + _test_calls_function, + discover_tests, +) +from codeflash.models.function_types import FunctionToOptimize + +GO_TEST_SOURCE = """\ +package calc + +import "testing" + +func TestAdd(t *testing.T) { +\tresult := Add(2, 3) +\tif result != 5 { +\t\tt.Fail() +\t} +} + +func TestSubtract(t *testing.T) { +\tresult := Subtract(5, 3) +\tif result != 2 { +\t\tt.Fail() +\t} +} + +func TestHelper(t *testing.T) { +\tx := 1 + 2 +\t_ = x +} +""" + + +class TestExtractTargetName: + def test_simple(self) -> None: + assert _extract_target_name("TestAdd") == "Add" + + def test_with_underscore_suffix(self) -> None: + assert _extract_target_name("TestAdd_negative") == "Add" + + def test_long_name(self) -> None: + assert _extract_target_name("TestFibonacci") == "Fibonacci" + + def test_bare_test(self) -> None: + assert _extract_target_name("Test") is None + + def test_not_a_test(self) -> None: + assert _extract_target_name("NotATest") is None + + +class TestExtractTestBody: + def test_extracts_body(self) -> None: + body = _extract_test_body(GO_TEST_SOURCE, "TestAdd") + assert body == "\n\tresult := Add(2, 3)\n\tif result != 5 {\n\t\tt.Fail()\n\t}\n" + + def test_extracts_second_body(self) -> None: + body = _extract_test_body(GO_TEST_SOURCE, "TestSubtract") + assert body == "\n\tresult := Subtract(5, 3)\n\tif result != 2 {\n\t\tt.Fail()\n\t}\n" + + def test_missing_function(self) -> None: + assert _extract_test_body(GO_TEST_SOURCE, "TestMissing") is None + + +class TestTestCallsFunction: + def test_calls_add(self) -> None: + assert _test_calls_function(GO_TEST_SOURCE, "TestAdd", "Add") is True + + def test_does_not_call_subtract(self) -> None: + assert _test_calls_function(GO_TEST_SOURCE, "TestAdd", "Subtract") is False + + def test_helper_does_not_call_add(self) -> None: + assert _test_calls_function(GO_TEST_SOURCE, "TestHelper", "Add") is False + + +class TestDiscoverTests: + def test_matches_by_name_convention(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc Add(a, b int) int { return a + b }\n", encoding="utf-8" + ) + (root / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Add", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "Add" in result + assert len(result["Add"]) == 1 + assert result["Add"][0].test_name == "TestAdd" + + def test_matches_multiple_functions(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc Add(a, b int) int { return a + b }\n\nfunc Subtract(a, b int) int { return a - b }\n", + encoding="utf-8", + ) + (root / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [ + FunctionToOptimize(function_name="Add", file_path=root / "calc.go", language="go"), + FunctionToOptimize(function_name="Subtract", file_path=root / "calc.go", language="go"), + ] + result = discover_tests(root, funcs) + assert "Add" in result + assert "Subtract" in result + assert result["Add"][0].test_name == "TestAdd" + assert result["Subtract"][0].test_name == "TestSubtract" + + def test_no_match_returns_empty(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc Multiply(a, b int) int { return a * b }\n", encoding="utf-8" + ) + (root / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Multiply", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "Multiply" not in result + + def test_no_test_files(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text("package calc\n\nfunc Add(a, b int) int { return a + b }\n", encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Add", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert result == {} + + def test_subdirectory_test_files(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + pkg = root / "pkg" + pkg.mkdir() + (pkg / "calc.go").write_text( + "package calc\n\nfunc Add(a, b int) int { return a + b }\n", encoding="utf-8" + ) + (pkg / "calc_test.go").write_text(GO_TEST_SOURCE, encoding="utf-8") + + funcs = [FunctionToOptimize(function_name="Add", file_path=pkg / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "Add" in result + assert result["Add"][0].test_file == pkg / "calc_test.go" + + def test_fallback_content_match(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + (root / "calc.go").write_text( + "package calc\n\nfunc DoMath(a, b int) int { return a + b }\n", encoding="utf-8" + ) + (root / "calc_test.go").write_text( + 'package calc\n\nimport "testing"\n\nfunc TestComputation(t *testing.T) {\n' + "\tresult := DoMath(2, 3)\n\tif result != 5 {\n\t\tt.Fail()\n\t}\n}\n", + encoding="utf-8", + ) + + funcs = [FunctionToOptimize(function_name="DoMath", file_path=root / "calc.go", language="go")] + result = discover_tests(root, funcs) + assert "DoMath" in result + assert result["DoMath"][0].test_name == "TestComputation" diff --git a/tests/test_languages/test_golang/test_test_runner.py b/tests/test_languages/test_golang/test_test_runner.py new file mode 100644 index 000000000..c812f8c60 --- /dev/null +++ b/tests/test_languages/test_golang/test_test_runner.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from pathlib import Path + +from codeflash.languages.golang.test_runner import parse_go_test_json, parse_test_results + + +GO_TEST_JSON_ALL_PASS = """\ +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestAdd"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestAdd","Output":"=== RUN TestAdd\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestAdd","Output":"--- PASS: TestAdd (0.00s)\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"pass","Package":"example.com/calc","Test":"TestAdd","Elapsed":0.001} +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestSub"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestSub","Output":"--- PASS: TestSub (0.00s)\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"pass","Package":"example.com/calc","Test":"TestSub","Elapsed":0.002} +""" + +GO_TEST_JSON_WITH_FAILURE = """\ +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestAdd"} +{"Time":"2024-01-01T00:00:00Z","Action":"pass","Package":"example.com/calc","Test":"TestAdd","Elapsed":0.001} +{"Time":"2024-01-01T00:00:00Z","Action":"run","Package":"example.com/calc","Test":"TestBroken"} +{"Time":"2024-01-01T00:00:00Z","Action":"output","Package":"example.com/calc","Test":"TestBroken","Output":" calc_test.go:15: expected 5, got 3\\n"} +{"Time":"2024-01-01T00:00:00Z","Action":"fail","Package":"example.com/calc","Test":"TestBroken","Elapsed":0.003} +""" + + +class TestParseGoTestJson: + def test_all_pass(self) -> None: + results = parse_go_test_json(GO_TEST_JSON_ALL_PASS) + assert len(results) == 2 + by_name = {r.test_name: r for r in results} + assert by_name["TestAdd"].passed is True + assert by_name["TestAdd"].runtime_ns == 1_000_000 + assert by_name["TestSub"].passed is True + assert by_name["TestSub"].runtime_ns == 2_000_000 + + def test_with_failure(self) -> None: + results = parse_go_test_json(GO_TEST_JSON_WITH_FAILURE) + assert len(results) == 2 + by_name = {r.test_name: r for r in results} + assert by_name["TestAdd"].passed is True + assert by_name["TestBroken"].passed is False + assert by_name["TestBroken"].error_message == "Test TestBroken failed" + + def test_empty_input(self) -> None: + results = parse_go_test_json("") + assert results == [] + + def test_invalid_json_lines_skipped(self) -> None: + json_output = 'not json\n{"Action":"pass","Package":"calc","Test":"TestOk","Elapsed":0.001}\n' + results = parse_go_test_json(json_output) + assert len(results) == 1 + assert results[0].test_name == "TestOk" + assert results[0].passed is True + + def test_package_level_events_ignored(self) -> None: + json_output = '{"Action":"pass","Package":"example.com/calc","Elapsed":0.5}\n' + results = parse_go_test_json(json_output) + assert results == [] + + def test_runtime_ns_conversion(self) -> None: + json_output = '{"Action":"pass","Package":"calc","Test":"TestFast","Elapsed":0.0005}\n' + results = parse_go_test_json(json_output) + assert len(results) == 1 + assert results[0].runtime_ns == 500_000 + + def test_zero_elapsed(self) -> None: + json_output = '{"Action":"pass","Package":"calc","Test":"TestZero","Elapsed":0}\n' + results = parse_go_test_json(json_output) + assert len(results) == 1 + assert results[0].runtime_ns is None + + +class TestParseTestResults: + def test_reads_from_file(self, tmp_path: Path) -> None: + json_file = (tmp_path / "results.jsonl").resolve() + json_file.write_text( + '{"Action":"pass","Package":"calc","Test":"TestAdd","Elapsed":0.001}\n', + encoding="utf-8", + ) + results = parse_test_results(json_file, "") + assert len(results) == 1 + assert results[0].test_name == "TestAdd" + assert results[0].passed is True + + def test_falls_back_to_stdout(self, tmp_path: Path) -> None: + missing_file = (tmp_path / "missing.jsonl").resolve() + stdout = '{"Action":"fail","Package":"calc","Test":"TestBad","Elapsed":0.002}\n' + results = parse_test_results(missing_file, stdout) + assert len(results) == 1 + assert results[0].test_name == "TestBad" + assert results[0].passed is False From ac478de753070c2f0bef682b1d47b6e3704821e3 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 23 Apr 2026 17:40:20 +0200 Subject: [PATCH 03/10] go function optimizer and handle global vars context correctly --- codeflash/languages/golang/context.py | 36 +- .../languages/golang/function_optimizer.py | 160 +++++ codeflash/languages/golang/parser.py | 84 +++ codeflash/languages/golang/replacement.py | 73 ++ codeflash/languages/golang/support.py | 4 +- .../test_golang/test_context.py | 114 +++ .../test_golang/test_function_optimizer.py | 653 ++++++++++++++++++ .../test_languages/test_golang/test_parser.py | 122 ++++ .../test_golang/test_replacement.py | 460 +++++++++++- 9 files changed, 1701 insertions(+), 5 deletions(-) create mode 100644 codeflash/languages/golang/function_optimizer.py create mode 100644 tests/test_languages/test_golang/test_function_optimizer.py diff --git a/codeflash/languages/golang/context.py b/codeflash/languages/golang/context.py index 499da4ba1..3808a0d1e 100644 --- a/codeflash/languages/golang/context.py +++ b/codeflash/languages/golang/context.py @@ -36,9 +36,15 @@ def extract_code_context( imports = analyzer.find_imports(source) import_lines = [_import_to_line(imp) for imp in imports] - read_only_context = "" + read_only_parts: list[str] = [] if receiver_type: - read_only_context = _extract_struct_context(source, receiver_type, analyzer) + struct_ctx = _extract_struct_context(source, receiver_type, analyzer) + if struct_ctx: + read_only_parts.append(struct_ctx) + + init_ctx = _extract_init_context(source, analyzer) + if init_ctx: + read_only_parts.append(init_ctx) helpers = find_helper_functions(source, function, analyzer) @@ -46,7 +52,7 @@ def extract_code_context( target_code=target_code, target_file=function.file_path, helper_functions=helpers, - read_only_context=read_only_context, + read_only_context="\n\n".join(read_only_parts), imports=import_lines, language=Language.GO, ) @@ -125,3 +131,27 @@ def _extract_struct_context(source: str, struct_name: str, analyzer: GoAnalyzer) lines = source.splitlines() return "\n".join(lines[s.starting_line - 1 : s.ending_line]) return "" + + +def _extract_init_context(source: str, analyzer: GoAnalyzer) -> str: + init_source = analyzer.extract_function_source(source, "init") + if init_source is None: + return "" + + init_ids = analyzer.collect_body_identifiers(source, "init") + if not init_ids: + return init_source + + parts: list[str] = [] + + for decl in analyzer.find_global_declarations(source): + if init_ids & set(decl.names): + parts.append(decl.source_code) + + for struct in analyzer.find_structs(source): + if struct.name in init_ids: + lines = source.splitlines() + parts.append("\n".join(lines[struct.starting_line - 1 : struct.ending_line])) + + parts.append(init_source) + return "\n\n".join(parts) diff --git a/codeflash/languages/golang/function_optimizer.py b/codeflash/languages/golang/function_optimizer.py new file mode 100644 index 000000000..c3113276e --- /dev/null +++ b/codeflash/languages/golang/function_optimizer.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import hashlib +from collections import defaultdict +from typing import TYPE_CHECKING + +from codeflash.code_utils.code_utils import encoded_tokens_len +from codeflash.code_utils.config_consts import ( + OPTIMIZATION_CONTEXT_TOKEN_LIMIT, + READ_WRITABLE_LIMIT_ERROR, + TESTGEN_CONTEXT_TOKEN_LIMIT, + TESTGEN_LIMIT_ERROR, +) +from codeflash.either import Failure, Success +from codeflash.languages.function_optimizer import FunctionOptimizer +from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource +from codeflash.verification.equivalence import compare_test_results + +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.either import Result + from codeflash.languages.base import CodeContext, HelperFunction + from codeflash.models.models import OriginalCodeBaseline, TestDiff, TestResults + + +class GoFunctionOptimizer(FunctionOptimizer): + def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: + from codeflash.languages import get_language_support + from codeflash.languages.base import Language + + language = Language(self.function_to_optimize.language) + lang_support = get_language_support(language) + + try: + code_context = lang_support.extract_code_context( + self.function_to_optimize, self.project_root, self.project_root + ) + return Success( + _build_optimization_context( + code_context, + self.function_to_optimize.file_path, + self.function_to_optimize.language, + self.project_root, + ) + ) + except ValueError as e: + return Failure(str(e)) + + def compare_candidate_results( + self, + baseline_results: OriginalCodeBaseline, + candidate_behavior_results: TestResults, + optimization_candidate_index: int, + ) -> tuple[bool, list[TestDiff]]: + return compare_test_results( + baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True + ) + + def replace_function_and_helpers_with_optimized_code( + self, + code_context: CodeOptimizationContext, + optimized_code: CodeStringsMarkdown, + original_helper_code: dict[Path, str], + ) -> bool: + from codeflash.languages.code_replacer import replace_function_definitions_for_language + + did_update = False + for module_abspath, qualified_names in self.group_functions_by_file(code_context).items(): + did_update |= replace_function_definitions_for_language( + function_names=list(qualified_names), + optimized_code=optimized_code, + module_abspath=module_abspath, + project_root_path=self.project_root, + lang_support=self.language_support, + function_to_optimize=self.function_to_optimize, + ) + return did_update + + +def _build_optimization_context( + code_context: CodeContext, + file_path: Path, + language: str, + project_root: Path, + optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, + testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, +) -> CodeOptimizationContext: + if code_context.imports: + inner = "\n".join(f"\t{imp}" for imp in code_context.imports) + imports_code = f"import (\n{inner}\n)" + else: + imports_code = "" + + try: + target_relative_path = file_path.resolve().relative_to(project_root.resolve()) + except ValueError: + target_relative_path = file_path + + helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list) + helper_function_sources = [] + + for helper in code_context.helper_functions: + helpers_by_file[helper.file_path].append(helper) + helper_function_sources.append( + FunctionSource( + file_path=helper.file_path, + qualified_name=helper.qualified_name, + fully_qualified_name=helper.qualified_name, + only_function_name=helper.name, + source_code=helper.source_code, + ) + ) + + target_file_code = code_context.target_code + same_file_helpers = helpers_by_file.get(file_path, []) + if same_file_helpers: + helper_code = "\n\n".join(h.source_code for h in same_file_helpers) + target_file_code = target_file_code + "\n\n" + helper_code + + if imports_code: + target_file_code = imports_code + "\n\n" + target_file_code + + read_writable_code_strings = [CodeString(code=target_file_code, file_path=target_relative_path, language=language)] + + for helper_file_path, file_helpers in helpers_by_file.items(): + if helper_file_path == file_path: + continue + try: + helper_relative_path = helper_file_path.resolve().relative_to(project_root.resolve()) + except ValueError: + helper_relative_path = helper_file_path + combined_helper_code = "\n\n".join(h.source_code for h in file_helpers) + read_writable_code_strings.append( + CodeString(code=combined_helper_code, file_path=helper_relative_path, language=language) + ) + + read_writable_code = CodeStringsMarkdown(code_strings=read_writable_code_strings, language=language) + testgen_context = CodeStringsMarkdown(code_strings=read_writable_code_strings.copy(), language=language) + + read_writable_tokens = encoded_tokens_len(read_writable_code.markdown) + if read_writable_tokens > optim_token_limit: + raise ValueError(READ_WRITABLE_LIMIT_ERROR) + + testgen_tokens = encoded_tokens_len(testgen_context.markdown) + if testgen_tokens > testgen_token_limit: + raise ValueError(TESTGEN_LIMIT_ERROR) + + code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest() + + return CodeOptimizationContext( + testgen_context=testgen_context, + read_writable_code=read_writable_code, + read_only_context_code=code_context.read_only_context, + hashing_code_context=read_writable_code.flat, + hashing_code_context_hash=code_hash, + helper_functions=helper_function_sources, + testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources], + preexisting_objects=set(), + ) diff --git a/codeflash/languages/golang/parser.py b/codeflash/languages/golang/parser.py index 05f378ad7..e2d43ffc2 100644 --- a/codeflash/languages/golang/parser.py +++ b/codeflash/languages/golang/parser.py @@ -81,6 +81,15 @@ class GoImportInfo: ending_line: int +@dataclass(frozen=True) +class GoGlobalDeclaration: + names: tuple[str, ...] + kind: str + source_code: str + starting_line: int + ending_line: int + + class GoAnalyzer: def __init__(self) -> None: self._parser = _get_go_parser() @@ -189,6 +198,45 @@ def find_imports(self, source: str) -> list[GoImportInfo]: ) return results + def find_global_declarations(self, source: str) -> list[GoGlobalDeclaration]: + tree = self.parse(source) + results: list[GoGlobalDeclaration] = [] + for node in tree.root_node.children: + if node.type in ("var_declaration", "const_declaration"): + kind = "var" if node.type == "var_declaration" else "const" + names = _extract_declaration_names(node, self) + if names: + results.append( + GoGlobalDeclaration( + names=tuple(names), + kind=kind, + source_code=self.get_node_text(node), + starting_line=node.start_point.row + 1, + ending_line=node.end_point.row + 1, + ) + ) + return results + + def collect_body_identifiers(self, source: str, func_name: str, receiver_type: str | None = None) -> set[str]: + tree = self.parse(source) + for node in tree.root_node.children: + if receiver_type is None and node.type == "function_declaration": + name_node = node.child_by_field_name("name") + if name_node is not None and self.get_node_text(name_node) == func_name: + body = node.child_by_field_name("body") + return _collect_identifiers(body) if body else set() + if receiver_type is not None and node.type == "method_declaration": + name_node = node.child_by_field_name("name") + if name_node is None or self.get_node_text(name_node) != func_name: + continue + recv_node = node.child_by_field_name("receiver") + if recv_node is not None: + recv_name, _ = self.parse_receiver(recv_node) + if recv_name == receiver_type: + body = node.child_by_field_name("body") + return _collect_identifiers(body) if body else set() + return set() + def find_package_name(self, source: str) -> str | None: tree = self.parse(source) for node in tree.root_node.children: @@ -317,6 +365,42 @@ def _iter_import_specs(import_node: Node) -> list[Node]: return results +def _extract_declaration_names(node: Node, analyzer: GoAnalyzer) -> list[str]: + names: list[str] = [] + for child in node.children: + if child.type in ("var_spec", "const_spec"): + name_node = child.child_by_field_name("name") + if name_node is not None: + names.append(analyzer.get_node_text(name_node)) + elif child.type in ("var_spec_list", "const_spec_list"): + for spec in child.children: + if spec.type in ("var_spec", "const_spec"): + name_node = spec.child_by_field_name("name") + if name_node is not None: + names.append(analyzer.get_node_text(name_node)) + return names + + +def _collect_identifiers(node: Node | None) -> set[str]: + if node is None: + return set() + ids: set[str] = set() + stack = [node] + while stack: + n = stack.pop() + if n.type in ("identifier", "type_identifier"): + text = n.parent + if text is not None and text.type not in ("parameter_declaration", "short_var_declaration"): + ids.add(n.text.decode("utf-8") if n.text else "") + elif text is not None and text.type == "short_var_declaration": + name_node = text.child_by_field_name("left") + if name_node is not n and (name_node is None or n not in (name_node, *tuple(name_node.children))): + ids.add(n.text.decode("utf-8") if n.text else "") + stack.extend(n.children) + ids.discard("") + return ids + + def _find_preceding_comment_line(node: Node) -> int | None: prev = node.prev_named_sibling if prev is None: diff --git a/codeflash/languages/golang/replacement.py b/codeflash/languages/golang/replacement.py index 63dc1dcc4..68b1c811e 100644 --- a/codeflash/languages/golang/replacement.py +++ b/codeflash/languages/golang/replacement.py @@ -55,6 +55,11 @@ def replace_function( def add_global_declarations(optimized_code: str, original_source: str, analyzer: GoAnalyzer | None = None) -> str: analyzer = analyzer or GoAnalyzer() + merged = _merge_imports(optimized_code, original_source, analyzer) + return _merge_global_var_const(optimized_code, merged, analyzer) + + +def _merge_imports(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str: opt_imports = analyzer.find_imports(optimized_code) orig_imports = analyzer.find_imports(original_source) orig_paths = {imp.path for imp in orig_imports} @@ -91,6 +96,74 @@ def add_global_declarations(optimized_code: str, original_source: str, analyzer: return "".join([*lines[:insert_at], import_block, *lines[insert_at:]]) +def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: GoAnalyzer) -> str: + opt_decls = analyzer.find_global_declarations(optimized_code) + if not opt_decls: + return original_source + + orig_decls = analyzer.find_global_declarations(original_source) + orig_names_to_decl: dict[str, object] = {} + for decl in orig_decls: + for name in decl.names: + orig_names_to_decl[name] = decl + + new_decls: list[str] = [] + replaced_decls: set[int] = set() + + for opt_decl in opt_decls: + overlapping_orig = None + for name in opt_decl.names: + if name in orig_names_to_decl: + overlapping_orig = orig_names_to_decl[name] + break + + if overlapping_orig is None: + new_decls.append(opt_decl.source_code) + elif overlapping_orig.source_code.strip() != opt_decl.source_code.strip(): + orig_id = id(overlapping_orig) + if orig_id not in replaced_decls: + replaced_decls.add(orig_id) + original_source = _replace_declaration_block(original_source, overlapping_orig, opt_decl.source_code) + + if new_decls: + original_source = _insert_new_declarations(original_source, new_decls, analyzer) + + return original_source + + +def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str: + lines = source.splitlines(keepends=True) + start = orig_decl.starting_line - 1 + end = orig_decl.ending_line + replacement = new_source_code.rstrip("\n") + "\n" + return "".join([*lines[:start], replacement, *lines[end:]]) + + +def _insert_new_declarations(source: str, new_decls: list[str], analyzer: GoAnalyzer) -> str: + lines = source.splitlines(keepends=True) + + insert_at = _find_declarations_insert_point(source, analyzer) + + block = "\n".join(new_decls) + "\n\n" + return "".join([*lines[:insert_at], block, *lines[insert_at:]]) + + +def _find_declarations_insert_point(source: str, analyzer: GoAnalyzer) -> int: + tree = analyzer.parse(source) + last_line = 0 + for node in tree.root_node.children: + if node.type in ("import_declaration", "var_declaration", "const_declaration"): + candidate = node.end_point.row + 1 + last_line = max(last_line, candidate) + if last_line > 0: + return last_line + + for node in tree.root_node.children: + if node.type == "package_clause": + return node.end_point.row + 1 + return 0 + + def remove_test_functions(test_source: str, functions_to_remove: list[str], analyzer: GoAnalyzer | None = None) -> str: analyzer = analyzer or GoAnalyzer() tree = analyzer.parse(test_source) diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index 4aacf1768..ec68c7d08 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -85,7 +85,9 @@ def test_result_serialization_format(self) -> str: @property def function_optimizer_class(self) -> type: - raise NotImplementedError + from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer + + return GoFunctionOptimizer def discover_functions( self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None diff --git a/tests/test_languages/test_golang/test_context.py b/tests/test_languages/test_golang/test_context.py index a21844d7a..c79092511 100644 --- a/tests/test_languages/test_golang/test_context.py +++ b/tests/test_languages/test_golang/test_context.py @@ -195,6 +195,120 @@ def test_multi_import(self, tmp_path: Path) -> None: assert ctx.imports == ['"fmt"', '"os"', 'str "strings"'] +GO_SOURCE_WITH_INIT = """\ +package server + +import "sync" + +var ( +\tglobalCache map[string]int +\tmu sync.Mutex +) + +const MaxRetries = 5 + +type Config struct { +\tName string +\tMax int +} + +func init() { +\tglobalCache = make(map[string]int) +\tglobalCache["default"] = 0 +\tmu.Lock() +\tmu.Unlock() +} + +func Process() int { +\treturn MaxRetries +} +""" + + +class TestExtractCodeContextWithInit: + def test_init_in_read_only_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert "func init()" in ctx.read_only_context + + def test_init_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert "globalCache" in ctx.read_only_context + assert "mu" in ctx.read_only_context + + def test_init_not_in_helpers(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + helper_names = [h.name for h in ctx.helper_functions] + assert "init" not in helper_names + + def test_no_init_no_extra_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + func = FunctionToOptimize(function_name="Add", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + assert "func init()" not in ctx.read_only_context + + def test_full_init_read_only_context(self, tmp_path: Path) -> None: + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(GO_SOURCE_WITH_INIT, encoding="utf-8") + func = FunctionToOptimize(function_name="Process", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) + expected = ( + "var (\n" + "\tglobalCache map[string]int\n" + "\tmu sync.Mutex\n" + ")\n" + "\n" + "func init() {\n" + "\tglobalCache = make(map[string]int)\n" + "\tglobalCache[\"default\"] = 0\n" + "\tmu.Lock()\n" + "\tmu.Unlock()\n" + "}" + ) + assert ctx.read_only_context == expected + + def test_method_with_init_combines_struct_and_init_context(self, tmp_path: Path) -> None: + source = """\ +package server + +var globalOffset = 10 + +type Calc struct { +\tVal int +} + +func init() { +\tglobalOffset = 42 +} + +func (c *Calc) Compute() int { +\treturn c.Val + globalOffset +} +""" + source_file = (tmp_path / "server.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize( + function_name="Compute", + file_path=source_file, + parents=[FunctionParent(name="Calc", type="StructDef")], + language="go", + is_method=True, + ) + ctx = extract_code_context(func, tmp_path.resolve()) + assert "type Calc struct" in ctx.read_only_context + assert "func init()" in ctx.read_only_context + assert "var globalOffset = 10" in ctx.read_only_context + + class TestFindHelperFunctions: def test_skips_init_and_main(self, tmp_path: Path) -> None: source = "package main\n\nfunc init() { println() }\n\nfunc main() { println() }\n\nfunc Target() int { return 1 }\n" diff --git a/tests/test_languages/test_golang/test_function_optimizer.py b/tests/test_languages/test_golang/test_function_optimizer.py new file mode 100644 index 000000000..60915223b --- /dev/null +++ b/tests/test_languages/test_golang/test_function_optimizer.py @@ -0,0 +1,653 @@ +from __future__ import annotations + +import hashlib +from pathlib import Path +from textwrap import dedent +from typing import TYPE_CHECKING + +import pytest + +from codeflash.languages.golang.context import extract_code_context +from codeflash.languages.golang.function_optimizer import _build_optimization_context +from codeflash.models.function_types import FunctionParent, FunctionToOptimize + +if TYPE_CHECKING: + from codeflash.models.models import CodeOptimizationContext + +# --------------------------------------------------------------------------- +# Realistic Go sources used across test classes +# --------------------------------------------------------------------------- + +CALCULATOR_SOURCE = dedent("""\ + package calc + + import ( + \t"fmt" + \t"math" + \tstr "strings" + ) + + // Calculator holds running computation state. + type Calculator struct { + \tResult float64 + \tHistory []float64 + } + + // Formatter controls output rendering. + type Formatter interface { + \tFormat(val float64) string + } + + // Add returns the sum of two integers. + func Add(a, b int) int { + \treturn a + b + } + + func subtract(a, b int) int { + \treturn a - b + } + + func multiply(a, b int) int { + \treturn a * b + } + + // Greet builds a greeting message. + func Greet(name string) string { + \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) + } + + // AddFloat adds a float value and records history. + func (c *Calculator) AddFloat(val float64) float64 { + \tc.Result += val + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + // Sqrt computes the square root of the current result. + func (c *Calculator) Sqrt() float64 { + \tc.Result = math.Sqrt(c.Result) + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + // Reset zeroes out the calculator. + func (c Calculator) Reset() Calculator { + \tc.Result = 0 + \tc.History = nil + \treturn c + } +""") + +SIMPLE_SOURCE = dedent("""\ + package simple + + func Double(x int) int { + \treturn x * 2 + } +""") + +INIT_SOURCE = dedent("""\ + package server + + import ( + \t"fmt" + \t"sync" + ) + + var ( + \tglobalCache map[string]int + \tmu sync.Mutex + ) + + var singleVar = 42 + + const MaxRetries = 5 + + type Config struct { + \tName string + \tMax int + } + + func init() { + \tglobalCache = make(map[string]int) + \tglobalCache["default"] = 0 + \tdefaultCfg := Config{Name: "prod", Max: MaxRetries} + \t_ = defaultCfg + \tmu.Lock() + \tmu.Unlock() + } + + func Process() int { + \tfmt.Println("processing") + \treturn singleVar + MaxRetries + } +""") + + +# --------------------------------------------------------------------------- +# Helpers to drive the full extract → build pipeline +# --------------------------------------------------------------------------- + + +def _build_context_for_function( + source: str, + filename: str, + function_name: str, + tmp_path: Path, + parents: list[FunctionParent] | None = None, + is_method: bool = False, +) -> CodeOptimizationContext: + root = tmp_path.resolve() + source_file = (root / filename).resolve() + source_file.write_text(source, encoding="utf-8") + + func = FunctionToOptimize( + function_name=function_name, file_path=source_file, parents=parents or [], language="go", is_method=is_method + ) + code_context = extract_code_context(func, root) + return _build_optimization_context(code_context, source_file, "go", root) + + +# --------------------------------------------------------------------------- +# Tests: targeting a plain exported function +# --------------------------------------------------------------------------- + + +class TestBuildContextExportedFunction: + """Target: Add(a, b int) int — a plain exported function with a doc comment.""" + + def test_full_assembled_code_string(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + + expected = dedent("""\ + import ( + \t"fmt" + \t"math" + \tstr "strings" + ) + + // Add returns the sum of two integers. + func Add(a, b int) int { + \treturn a + b + } + + + func subtract(a, b int) int { + \treturn a - b + } + + func multiply(a, b int) int { + \treturn a * b + } + + // Greet builds a greeting message. + func Greet(name string) string { + \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) + } + + + // AddFloat adds a float value and records history. + func (c *Calculator) AddFloat(val float64) float64 { + \tc.Result += val + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + + // Sqrt computes the square root of the current result. + func (c *Calculator) Sqrt() float64 { + \tc.Result = math.Sqrt(c.Result) + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + + // Reset zeroes out the calculator. + func (c Calculator) Reset() Calculator { + \tc.Result = 0 + \tc.History = nil + \treturn c + } + """) + assert code == expected + + def test_code_excludes_package_clause(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "package calc" not in code + + def test_code_excludes_struct_definition(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "type Calculator struct" not in code + + def test_code_excludes_interface_definition(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "type Formatter interface" not in code + + def test_helpers_include_other_functions_and_methods(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + helper_names = sorted(h.only_function_name for h in result.helper_functions) + assert "subtract" in helper_names + assert "multiply" in helper_names + assert "Greet" in helper_names + assert "AddFloat" in helper_names + assert "Sqrt" in helper_names + assert "Reset" in helper_names + assert "Add" not in helper_names + + def test_helper_sources_are_full_functions(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + by_name = {h.only_function_name: h for h in result.helper_functions} + + assert by_name["subtract"].source_code == dedent("""\ + func subtract(a, b int) int { + \treturn a - b + }""") + + assert by_name["Greet"].source_code == dedent("""\ + // Greet builds a greeting message. + func Greet(name string) string { + \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) + } + """) + + def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + by_name = {h.only_function_name: h for h in result.helper_functions} + assert by_name["AddFloat"].qualified_name == "Calculator.AddFloat" + assert by_name["AddFloat"].fully_qualified_name == "Calculator.AddFloat" + assert by_name["subtract"].qualified_name == "subtract" + + def test_no_read_only_context_for_plain_function(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_only_context_code == "" + + def test_relative_path(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_writable_code.code_strings[0].file_path == Path("calc.go") + + def test_language_tag(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_writable_code.code_strings[0].language == "go" + + def test_testgen_fqns_match_helpers(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + fqns = set(result.testgen_helper_fqns) + helper_fqns = {h.fully_qualified_name for h in result.helper_functions} + assert fqns == helper_fqns + + +# --------------------------------------------------------------------------- +# Tests: targeting a method with a pointer receiver +# --------------------------------------------------------------------------- + + +class TestBuildContextPointerReceiverMethod: + """Target: (c *Calculator) AddFloat(val float64) — pointer receiver method.""" + + def _build(self, tmp_path: Path) -> CodeOptimizationContext: + return _build_context_for_function( + CALCULATOR_SOURCE, + "calc.go", + "AddFloat", + tmp_path, + parents=[FunctionParent(name="Calculator", type="StructDef")], + is_method=True, + ) + + def test_full_assembled_code_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + + expected = dedent("""\ + import ( + \t"fmt" + \t"math" + \tstr "strings" + ) + + // AddFloat adds a float value and records history. + func (c *Calculator) AddFloat(val float64) float64 { + \tc.Result += val + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + + // Add returns the sum of two integers. + func Add(a, b int) int { + \treturn a + b + } + + + func subtract(a, b int) int { + \treturn a - b + } + + func multiply(a, b int) int { + \treturn a * b + } + + // Greet builds a greeting message. + func Greet(name string) string { + \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) + } + + + // Sqrt computes the square root of the current result. + func (c *Calculator) Sqrt() float64 { + \tc.Result = math.Sqrt(c.Result) + \tc.History = append(c.History, c.Result) + \treturn c.Result + } + + + // Reset zeroes out the calculator. + func (c Calculator) Reset() Calculator { + \tc.Result = 0 + \tc.History = nil + \treturn c + } + """) + assert code == expected + + def test_code_excludes_package_and_type_defs(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "package calc" not in code + assert "type Calculator struct" not in code + assert "type Formatter interface" not in code + + def test_read_only_context_is_struct_definition(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + assert result.read_only_context_code == dedent("""\ + type Calculator struct { + \tResult float64 + \tHistory []float64 + }""") + + def test_helpers_exclude_self_include_others(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + helper_names = sorted(h.only_function_name for h in result.helper_functions) + assert "AddFloat" not in helper_names + assert "Add" in helper_names + assert "subtract" in helper_names + assert "multiply" in helper_names + assert "Greet" in helper_names + assert "Sqrt" in helper_names + assert "Reset" in helper_names + + def test_target_not_duplicated_in_code_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + assert code.count("func (c *Calculator) AddFloat") == 1 + + +# --------------------------------------------------------------------------- +# Tests: targeting a value receiver method +# --------------------------------------------------------------------------- + + +class TestBuildContextValueReceiverMethod: + """Target: (c Calculator) Reset() — value receiver method.""" + + def _build(self, tmp_path: Path) -> CodeOptimizationContext: + return _build_context_for_function( + CALCULATOR_SOURCE, + "calc.go", + "Reset", + tmp_path, + parents=[FunctionParent(name="Calculator", type="StructDef")], + is_method=True, + ) + + def test_target_in_code_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + + expected_target = dedent("""\ + // Reset zeroes out the calculator. + func (c Calculator) Reset() Calculator { + \tc.Result = 0 + \tc.History = nil + \treturn c + }""") + assert code.count("func (c Calculator) Reset()") == 1 + assert expected_target in code + + def test_helpers_include_other_methods_on_same_struct(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + helper_names = sorted(h.only_function_name for h in result.helper_functions) + assert "Reset" not in helper_names + assert "AddFloat" in helper_names + assert "Sqrt" in helper_names + assert "Add" in helper_names + + def test_helper_code_in_assembled_string(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "func (c *Calculator) AddFloat" in code + assert "func (c *Calculator) Sqrt()" in code + assert "func Add(a, b int) int" in code + assert "func subtract(a, b int) int" in code + + def test_struct_in_read_only_context(self, tmp_path: Path) -> None: + result = self._build(tmp_path) + assert result.read_only_context_code == dedent("""\ + type Calculator struct { + \tResult float64 + \tHistory []float64 + }""") + + +# --------------------------------------------------------------------------- +# Tests: simple source with no imports, no methods, one function +# --------------------------------------------------------------------------- + + +class TestBuildContextMinimalSource: + """Target: Double(x int) — minimal file with no imports or structs.""" + + def test_no_imports_no_prefix(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert code == dedent("""\ + func Double(x int) int { + \treturn x * 2 + }""") + + def test_no_helpers(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + assert result.helper_functions == [] + assert result.testgen_helper_fqns == [] + + def test_empty_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + assert result.read_only_context_code == "" + + def test_preexisting_objects_empty(self, tmp_path: Path) -> None: + result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) + assert result.preexisting_objects == set() + + +# --------------------------------------------------------------------------- +# Tests: init function and globals in context +# --------------------------------------------------------------------------- + + +class TestBuildContextWithInit: + """Target: Process() — source has init(), global vars, consts, struct.""" + + def test_init_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "func init()" in result.read_only_context_code + + def test_referenced_globals_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "globalCache" in result.read_only_context_code + assert "mu" in result.read_only_context_code + + def test_referenced_const_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "MaxRetries" in result.read_only_context_code + + def test_referenced_struct_in_read_only_context(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + assert "type Config struct" in result.read_only_context_code + + def test_init_not_in_helpers(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + helper_names = [h.only_function_name for h in result.helper_functions] + assert "init" not in helper_names + + def test_init_not_in_read_writable_code(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + code = result.read_writable_code.code_strings[0].code + assert "func init()" not in code + + def test_full_read_only_context_string(self, tmp_path: Path) -> None: + result = _build_context_for_function(INIT_SOURCE, "server.go", "Process", tmp_path) + expected = dedent("""\ + var ( + \tglobalCache map[string]int + \tmu sync.Mutex + ) + + const MaxRetries = 5 + + type Config struct { + \tName string + \tMax int + } + + func init() { + \tglobalCache = make(map[string]int) + \tglobalCache["default"] = 0 + \tdefaultCfg := Config{Name: "prod", Max: MaxRetries} + \t_ = defaultCfg + \tmu.Lock() + \tmu.Unlock() + }""") + assert result.read_only_context_code == expected + + +class TestBuildContextNoInit: + """Source without init — verify no init context is added.""" + + def test_no_init_no_extra_read_only(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert "func init()" not in result.read_only_context_code + + def test_no_init_read_only_empty_for_function(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.read_only_context_code == "" + + +# --------------------------------------------------------------------------- +# Tests: subdirectory / relative path handling +# --------------------------------------------------------------------------- + + +class TestBuildContextSubdirectory: + """Source file in a pkg/ subdirectory.""" + + def test_relative_path_includes_subdir(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + pkg = root / "pkg" + pkg.mkdir() + source_file = (pkg / "calc.go").resolve() + source_file.write_text(SIMPLE_SOURCE, encoding="utf-8") + + func = FunctionToOptimize(function_name="Double", file_path=source_file, language="go") + ctx = extract_code_context(func, root) + result = _build_optimization_context(ctx, source_file, "go", root) + + assert result.read_writable_code.code_strings[0].file_path == Path("pkg/calc.go") + + +# --------------------------------------------------------------------------- +# Tests: hashing +# --------------------------------------------------------------------------- + + +class TestBuildContextHashing: + def test_hash_is_sha256_of_flat(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + expected_hash = hashlib.sha256(result.read_writable_code.flat.encode("utf-8")).hexdigest() + assert result.hashing_code_context_hash == expected_hash + + def test_hashing_code_equals_flat(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.hashing_code_context == result.read_writable_code.flat + + def test_different_targets_different_hashes(self, tmp_path: Path) -> None: + dir_a = tmp_path / "a" + dir_a.mkdir() + dir_b = tmp_path / "b" + dir_b.mkdir() + + r1 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", dir_a) + r2 = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Greet", dir_b) + + assert r1.hashing_code_context_hash != r2.hashing_code_context_hash + + +# --------------------------------------------------------------------------- +# Tests: testgen context +# --------------------------------------------------------------------------- + + +class TestBuildContextTestgen: + def test_testgen_matches_read_writable(self, tmp_path: Path) -> None: + result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) + assert result.testgen_context.markdown == result.read_writable_code.markdown + + +# --------------------------------------------------------------------------- +# Tests: token limit enforcement +# --------------------------------------------------------------------------- + + +class TestBuildContextTokenLimits: + def test_exceeds_optim_token_limit(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + source_file = (root / "big.go").resolve() + huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n' + source_file.write_text(huge_code, encoding="utf-8") + + func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go") + ctx = extract_code_context(func, root) + + with pytest.raises(ValueError, match="Read-writable code has exceeded token limit"): + _build_optimization_context(ctx, source_file, "go", root, optim_token_limit=10) + + def test_exceeds_testgen_token_limit(self, tmp_path: Path) -> None: + root = tmp_path.resolve() + source_file = (root / "big.go").resolve() + huge_code = "package big\n\nfunc Big() string {\n\treturn " + '"x" + ' * 100000 + '"x"\n}\n' + source_file.write_text(huge_code, encoding="utf-8") + + func = FunctionToOptimize(function_name="Big", file_path=source_file, language="go") + ctx = extract_code_context(func, root) + + with pytest.raises(ValueError, match="Testgen code context has exceeded token limit"): + _build_optimization_context( + ctx, source_file, "go", root, optim_token_limit=1_000_000, testgen_token_limit=10 + ) + + +# --------------------------------------------------------------------------- +# Tests: GoSupport wiring +# --------------------------------------------------------------------------- + + +class TestGoSupportFunctionOptimizerClass: + def test_returns_go_function_optimizer(self) -> None: + from codeflash.languages.golang.function_optimizer import GoFunctionOptimizer + from codeflash.languages.golang.support import GoSupport + + support = GoSupport() + assert support.function_optimizer_class is GoFunctionOptimizer diff --git a/tests/test_languages/test_golang/test_parser.py b/tests/test_languages/test_golang/test_parser.py index 5ce663227..2179e92db 100644 --- a/tests/test_languages/test_golang/test_parser.py +++ b/tests/test_languages/test_golang/test_parser.py @@ -216,3 +216,125 @@ def test_extract_method_source(self) -> None: def test_extract_nonexistent(self) -> None: analyzer = GoAnalyzer() assert analyzer.extract_function_source(GO_SOURCE, "DoesNotExist") is None + + +GLOBALS_SOURCE = """\ +package server + +import "sync" + +var ( +\tglobalCache map[string]int +\tmu sync.Mutex +) + +var singleVar = 42 + +const MaxRetries = 5 + +const ( +\tDefaultName = "prod" +\tTimeout = 30 +) + +type Config struct { +\tName string +\tMax int +} + +func init() { +\tglobalCache = make(map[string]int) +\tglobalCache["default"] = 0 +\tdefaultCfg := Config{Name: DefaultName, Max: MaxRetries} +\t_ = defaultCfg +\tmu.Lock() +\tmu.Unlock() +} + +func Process() int { +\treturn singleVar + MaxRetries +} +""" + + +class TestGoAnalyzerGlobalDeclarations: + def test_find_var_group(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + var_decls = [d for d in decls if d.kind == "var"] + all_names = [name for d in var_decls for name in d.names] + assert "globalCache" in all_names + assert "mu" in all_names + assert "singleVar" in all_names + + def test_find_const_group(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + const_decls = [d for d in decls if d.kind == "const"] + all_names = [name for d in const_decls for name in d.names] + assert "MaxRetries" in all_names + assert "DefaultName" in all_names + assert "Timeout" in all_names + + def test_grouped_var_names_together(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + var_group = next(d for d in decls if "globalCache" in d.names) + assert var_group.names == ("globalCache", "mu") + + def test_single_var(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + single = next(d for d in decls if "singleVar" in d.names) + assert single.kind == "var" + assert single.source_code == "var singleVar = 42" + + def test_const_group_source_code(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations(GLOBALS_SOURCE) + group = next(d for d in decls if "DefaultName" in d.names) + assert "DefaultName" in group.source_code + assert "Timeout" in group.source_code + + def test_no_globals_in_clean_source(self) -> None: + analyzer = GoAnalyzer() + decls = analyzer.find_global_declarations("package main\n\nfunc main() {}\n") + assert decls == [] + + +class TestGoAnalyzerCollectBodyIdentifiers: + def test_init_body_identifiers(self) -> None: + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "init") + assert "globalCache" in ids + assert "Config" in ids + assert "DefaultName" in ids + assert "MaxRetries" in ids + assert "mu" in ids + + def test_process_body_identifiers(self) -> None: + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "Process") + assert "singleVar" in ids + assert "MaxRetries" in ids + + def test_nonexistent_function_returns_empty(self) -> None: + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(GLOBALS_SOURCE, "DoesNotExist") + assert ids == set() + + def test_method_body_identifiers(self) -> None: + source = """\ +package calc + +type Calc struct{ val int } + +var offset = 10 + +func (c *Calc) Compute() int { +\treturn c.val + offset +} +""" + analyzer = GoAnalyzer() + ids = analyzer.collect_body_identifiers(source, "Compute", receiver_type="Calc") + assert "offset" in ids diff --git a/tests/test_languages/test_golang/test_replacement.py b/tests/test_languages/test_golang/test_replacement.py index c36f3a012..5cd444aab 100644 --- a/tests/test_languages/test_golang/test_replacement.py +++ b/tests/test_languages/test_golang/test_replacement.py @@ -74,7 +74,7 @@ def test_replace_preserves_surrounding_code(self) -> None: assert result == expected -class TestAddGlobalDeclarations: +class TestAddGlobalDeclarationsImports: def test_add_import_to_existing_block(self) -> None: original = 'package calc\n\nimport (\n\t"fmt"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' optimized = 'package calc\n\nimport (\n\t"fmt"\n\t"math"\n)\n\nfunc Add(a, b int) int {\n\treturn a + b\n}\n' @@ -102,6 +102,464 @@ def test_no_new_imports_returns_unchanged(self) -> None: assert result == source +class TestAddGlobalDeclarationsNewVar: + def test_add_single_new_var(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var cache = make(map[int]int)\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "var cache = make(map[int]int)\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_add_grouped_var_block(self) -> None: + original = ( + "package server\n\n" + 'import "fmt"\n\n' + "func Process() {\n" + "\tfmt.Println()\n" + "}\n" + ) + optimized = ( + "package server\n\n" + 'import "fmt"\n\n' + "var (\n" + "\tcache map[string]int\n" + "\tbuffer []byte\n" + ")\n\n" + "func Process() {\n" + "\tfmt.Println()\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + 'import "fmt"\n' + "var (\n" + "\tcache map[string]int\n" + "\tbuffer []byte\n" + ")\n\n" + "\n" + "func Process() {\n" + "\tfmt.Println()\n" + "}\n" + ) + assert result == expected + + def test_add_new_var_preserves_existing_var(self) -> None: + original = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var version = 1\n\n" + "var cache = make(map[int]int)\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "var version = 1\n" + "var cache = make(map[int]int)\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsNewConst: + def test_add_single_new_const(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const maxSize = 1024\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "const maxSize = 1024\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_add_grouped_const_block(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const (\n" + "\tMaxRetries = 5\n" + "\tTimeout = 30\n" + ")\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "const (\n" + "\tMaxRetries = 5\n" + "\tTimeout = 30\n" + ")\n\n" + "\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_add_new_const_preserves_existing_const(self) -> None: + original = ( + "package calc\n\n" + "const Pi = 3.14\n\n" + "func Area(r float64) float64 {\n" + "\treturn Pi * r * r\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const Pi = 3.14\n\n" + "const TwoPi = 6.28\n\n" + "func Area(r float64) float64 {\n" + "\treturn Pi * r * r\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "const Pi = 3.14\n" + "const TwoPi = 6.28\n\n" + "\n" + "func Area(r float64) float64 {\n" + "\treturn Pi * r * r\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsModifyVar: + def test_modify_single_var_value(self) -> None: + original = ( + "package calc\n\n" + "var bufferSize = 256\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var bufferSize = 1024\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "var bufferSize = 1024\n" + "\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + assert result == expected + + def test_modify_grouped_var_block(self) -> None: + original = ( + "package server\n\n" + "var (\n" + '\thost = "localhost"\n' + "\tport = 8080\n" + ")\n\n" + "func Addr() string {\n" + "\treturn host\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "var (\n" + '\thost = "0.0.0.0"\n' + "\tport = 9090\n" + ")\n\n" + "func Addr() string {\n" + "\treturn host\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "var (\n" + '\thost = "0.0.0.0"\n' + "\tport = 9090\n" + ")\n" + "\n" + "func Addr() string {\n" + "\treturn host\n" + "}\n" + ) + assert result == expected + + def test_modify_var_type(self) -> None: + original = ( + "package calc\n\n" + "var counter int\n\n" + "func Inc() {\n" + "\tcounter++\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "var counter int64\n\n" + "func Inc() {\n" + "\tcounter++\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "var counter int64\n" + "\n" + "func Inc() {\n" + "\tcounter++\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsModifyConst: + def test_modify_single_const_value(self) -> None: + original = ( + "package calc\n\n" + "const MaxRetries = 3\n\n" + "func Retries() int {\n" + "\treturn MaxRetries\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "const MaxRetries = 10\n\n" + "func Retries() int {\n" + "\treturn MaxRetries\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n\n" + "const MaxRetries = 10\n" + "\n" + "func Retries() int {\n" + "\treturn MaxRetries\n" + "}\n" + ) + assert result == expected + + def test_modify_const_group(self) -> None: + original = ( + "package server\n\n" + "const (\n" + "\tDefaultTimeout = 30\n" + "\tMaxConnections = 100\n" + ")\n\n" + "func Config() int {\n" + "\treturn DefaultTimeout\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "const (\n" + "\tDefaultTimeout = 60\n" + "\tMaxConnections = 500\n" + ")\n\n" + "func Config() int {\n" + "\treturn DefaultTimeout\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "const (\n" + "\tDefaultTimeout = 60\n" + "\tMaxConnections = 500\n" + ")\n" + "\n" + "func Config() int {\n" + "\treturn DefaultTimeout\n" + "}\n" + ) + assert result == expected + + +class TestAddGlobalDeclarationsMixed: + def test_new_import_and_new_var(self) -> None: + original = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + 'import "sync"\n\n' + "var mu sync.Mutex\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package calc\n" + "import (\n" + '\t"sync"\n' + ")\n" + "var mu sync.Mutex\n\n" + "\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + assert result == expected + + def test_new_and_modified_globals_together(self) -> None: + original = ( + "package server\n\n" + "var bufferSize = 256\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "var bufferSize = 1024\n\n" + "var cache = make(map[string]int)\n\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "var bufferSize = 1024\n" + "var cache = make(map[string]int)\n\n" + "\n" + "func Process() int {\n" + "\treturn bufferSize\n" + "}\n" + ) + assert result == expected + + def test_no_globals_in_optimized_returns_unchanged(self) -> None: + original = ( + "package calc\n\n" + "var version = 1\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + optimized = ( + "package calc\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + assert result == original + + def test_identical_globals_returns_unchanged(self) -> None: + source = ( + "package calc\n\n" + "var version = 1\n\n" + "const MaxSize = 100\n\n" + "func Add(a, b int) int {\n" + "\treturn a + b\n" + "}\n" + ) + result = add_global_declarations(source, source) + assert result == source + + def test_full_round_trip_new_import_var_const(self) -> None: + original = ( + "package server\n\n" + "import (\n" + '\t"fmt"\n' + ")\n\n" + "const Version = 1\n\n" + "func Handle() {\n" + "\tfmt.Println()\n" + "}\n" + ) + optimized = ( + "package server\n\n" + "import (\n" + '\t"fmt"\n' + '\t"sync"\n' + ")\n\n" + "const Version = 1\n\n" + "var mu sync.Mutex\n\n" + "const MaxConns = 100\n\n" + "func Handle() {\n" + "\tmu.Lock()\n" + "\tdefer mu.Unlock()\n" + "\tfmt.Println()\n" + "}\n" + ) + result = add_global_declarations(optimized, original) + expected = ( + "package server\n\n" + "import (\n" + '\t"fmt"\n' + '\t"sync"\n' + ")\n\n" + "const Version = 1\n" + "var mu sync.Mutex\n" + "const MaxConns = 100\n\n" + "\n" + "func Handle() {\n" + "\tfmt.Println()\n" + "}\n" + ) + assert result == expected + + class TestRemoveTestFunctions: def test_remove_single_function(self) -> None: test_source = ( From 5ba8b742df2ca676030d0a83f1ddd3e0d970c7d5 Mon Sep 17 00:00:00 2001 From: ali Date: Sat, 25 Apr 2026 11:11:00 +0300 Subject: [PATCH 04/10] fixes and sample code to optimize --- code_to_optimize/go/algorithms.go | 195 ++++++++++ code_to_optimize/go/algorithms_test.go | 165 +++++++++ code_to_optimize/go/calculator.go | 117 ++++++ code_to_optimize/go/calculator_test.go | 149 ++++++++ code_to_optimize/go/fibonacci.go | 108 ++++++ code_to_optimize/go/fibonacci_test.go | 138 +++++++ code_to_optimize/go/go.mod | 3 + code_to_optimize/go/graph.go | 197 ++++++++++ code_to_optimize/go/graph_test.go | 109 ++++++ code_to_optimize/go/matrix.go | 122 +++++++ code_to_optimize/go/matrix_test.go | 112 ++++++ code_to_optimize/go/sorting.go | 94 +++++ code_to_optimize/go/sorting_test.go | 122 +++++++ code_to_optimize/go/stringutils.go | 125 +++++++ code_to_optimize/go/stringutils_test.go | 216 +++++++++++ codeflash/api/aiservice.py | 4 +- codeflash/cli_cmds/cli.py | 10 +- codeflash/code_utils/config_parser.py | 36 +- codeflash/discovery/functions_to_optimize.py | 2 + codeflash/languages/function_optimizer.py | 7 +- codeflash/languages/golang/context.py | 45 ++- codeflash/languages/golang/parse.py | 205 +++++++++++ codeflash/languages/golang/parser.py | 2 +- codeflash/languages/golang/support.py | 87 ++++- codeflash/languages/golang/test_runner.py | 54 ++- .../test_golang/test_context.py | 71 +++- .../test_golang/test_function_optimizer.py | 135 +------ .../test_languages/test_golang/test_parse.py | 342 ++++++++++++++++++ .../test_golang/test_support.py | 112 +++++- .../test_golang/test_test_runner.py | 90 ++++- 30 files changed, 3003 insertions(+), 171 deletions(-) create mode 100644 code_to_optimize/go/algorithms.go create mode 100644 code_to_optimize/go/algorithms_test.go create mode 100644 code_to_optimize/go/calculator.go create mode 100644 code_to_optimize/go/calculator_test.go create mode 100644 code_to_optimize/go/fibonacci.go create mode 100644 code_to_optimize/go/fibonacci_test.go create mode 100644 code_to_optimize/go/go.mod create mode 100644 code_to_optimize/go/graph.go create mode 100644 code_to_optimize/go/graph_test.go create mode 100644 code_to_optimize/go/matrix.go create mode 100644 code_to_optimize/go/matrix_test.go create mode 100644 code_to_optimize/go/sorting.go create mode 100644 code_to_optimize/go/sorting_test.go create mode 100644 code_to_optimize/go/stringutils.go create mode 100644 code_to_optimize/go/stringutils_test.go create mode 100644 codeflash/languages/golang/parse.py create mode 100644 tests/test_languages/test_golang/test_parse.py diff --git a/code_to_optimize/go/algorithms.go b/code_to_optimize/go/algorithms.go new file mode 100644 index 000000000..47961e8ff --- /dev/null +++ b/code_to_optimize/go/algorithms.go @@ -0,0 +1,195 @@ +package sample + +import "strings" + +func TwoSum(nums []int, target int) [2]int { + for i := 0; i < len(nums); i++ { + for j := i + 1; j < len(nums); j++ { + if nums[i]+nums[j] == target { + return [2]int{i, j} + } + } + } + return [2]int{-1, -1} +} + +func FindDuplicates(nums []int) []int { + var result []int + for i := 0; i < len(nums); i++ { + found := false + for j := 0; j < i; j++ { + if nums[i] == nums[j] { + found = true + break + } + } + if found { + alreadyAdded := false + for _, r := range result { + if r == nums[i] { + alreadyAdded = true + break + } + } + if !alreadyAdded { + result = append(result, nums[i]) + } + } + } + return result +} + +func UniqueElements(nums []int) []int { + var result []int + for _, num := range nums { + found := false + for _, r := range result { + if r == num { + found = true + break + } + } + if !found { + result = append(result, num) + } + } + return result +} + +func MostFrequent(nums []int) int { + if len(nums) == 0 { + return 0 + } + + maxCount := 0 + maxNum := nums[0] + + for _, num := range nums { + count := 0 + for _, other := range nums { + if other == num { + count++ + } + } + if count > maxCount { + maxCount = count + maxNum = num + } + } + return maxNum +} + +func Intersection(a, b []int) []int { + var result []int + for _, x := range a { + for _, y := range b { + if x == y { + already := false + for _, r := range result { + if r == x { + already = true + break + } + } + if !already { + result = append(result, x) + } + } + } + } + return result +} + +func MergeSortedSlices(a, b []int) []int { + var result []int + result = append(result, a...) + result = append(result, b...) + + for i := 0; i < len(result); i++ { + for j := i + 1; j < len(result); j++ { + if result[j] < result[i] { + result[i], result[j] = result[j], result[i] + } + } + } + return result +} + +func LongestCommonPrefix(strs []string) string { + if len(strs) == 0 { + return "" + } + + prefix := strs[0] + for _, s := range strs[1:] { + for !strings.HasPrefix(s, prefix) { + prefix = prefix[:len(prefix)-1] + if prefix == "" { + return "" + } + } + } + return prefix +} + +func MaxSubarraySum(nums []int) int { + if len(nums) == 0 { + return 0 + } + + maxSum := nums[0] + for i := 0; i < len(nums); i++ { + for j := i; j < len(nums); j++ { + sum := 0 + for k := i; k <= j; k++ { + sum += nums[k] + } + if sum > maxSum { + maxSum = sum + } + } + } + return maxSum +} + +func IsPrime(n int) bool { + if n < 2 { + return false + } + for i := 2; i < n; i++ { + if n%i == 0 { + return false + } + } + return true +} + +func PrimesUpTo(limit int) []int { + var primes []int + for i := 2; i <= limit; i++ { + if IsPrime(i) { + primes = append(primes, i) + } + } + return primes +} + +func GCD(a, b int) int { + if a < 0 { + a = -a + } + if b < 0 { + b = -b + } + for b != 0 { + a, b = b, a%b + } + return a +} + +func LCM(a, b int) int { + if a == 0 || b == 0 { + return 0 + } + return a / GCD(a, b) * b +} diff --git a/code_to_optimize/go/algorithms_test.go b/code_to_optimize/go/algorithms_test.go new file mode 100644 index 000000000..a6ebc1485 --- /dev/null +++ b/code_to_optimize/go/algorithms_test.go @@ -0,0 +1,165 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestTwoSum(t *testing.T) { + got := TwoSum([]int{2, 7, 11, 15}, 9) + if got != [2]int{0, 1} { + t.Errorf("TwoSum([2,7,11,15], 9) = %v, want [0,1]", got) + } + + got = TwoSum([]int{1, 2, 3}, 10) + if got != [2]int{-1, -1} { + t.Errorf("TwoSum no match = %v, want [-1,-1]", got) + } +} + +func TestFindDuplicates(t *testing.T) { + got := FindDuplicates([]int{1, 2, 3, 2, 4, 3, 5}) + want := []int{2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("FindDuplicates = %v, want %v", got, want) + } + + got = FindDuplicates([]int{1, 2, 3}) + if len(got) != 0 { + t.Errorf("expected no duplicates, got %v", got) + } +} + +func TestUniqueElements(t *testing.T) { + got := UniqueElements([]int{1, 2, 2, 3, 3, 3, 4}) + want := []int{1, 2, 3, 4} + if !reflect.DeepEqual(got, want) { + t.Errorf("UniqueElements = %v, want %v", got, want) + } +} + +func TestMostFrequent(t *testing.T) { + got := MostFrequent([]int{1, 2, 2, 3, 3, 3, 2, 2}) + if got != 2 { + t.Errorf("MostFrequent = %d, want 2", got) + } + + got = MostFrequent([]int{}) + if got != 0 { + t.Errorf("MostFrequent empty = %d, want 0", got) + } +} + +func TestIntersection(t *testing.T) { + got := Intersection([]int{1, 2, 3, 4}, []int{3, 4, 5, 6}) + want := []int{3, 4} + if !reflect.DeepEqual(got, want) { + t.Errorf("Intersection = %v, want %v", got, want) + } + + got = Intersection([]int{1, 2}, []int{3, 4}) + if len(got) != 0 { + t.Errorf("expected empty intersection, got %v", got) + } +} + +func TestMergeSortedSlices(t *testing.T) { + got := MergeSortedSlices([]int{1, 3, 5}, []int{2, 4, 6}) + want := []int{1, 2, 3, 4, 5, 6} + if !reflect.DeepEqual(got, want) { + t.Errorf("MergeSortedSlices = %v, want %v", got, want) + } +} + +func TestLongestCommonPrefix(t *testing.T) { + got := LongestCommonPrefix([]string{"flower", "flow", "flight"}) + if got != "fl" { + t.Errorf("LongestCommonPrefix = %q, want \"fl\"", got) + } + + got = LongestCommonPrefix([]string{"dog", "racecar", "car"}) + if got != "" { + t.Errorf("LongestCommonPrefix = %q, want \"\"", got) + } + + got = LongestCommonPrefix([]string{}) + if got != "" { + t.Errorf("LongestCommonPrefix empty = %q, want \"\"", got) + } +} + +func TestMaxSubarraySum(t *testing.T) { + got := MaxSubarraySum([]int{-2, 1, -3, 4, -1, 2, 1, -5, 4}) + if got != 6 { + t.Errorf("MaxSubarraySum = %d, want 6", got) + } + + got = MaxSubarraySum([]int{-1, -2, -3}) + if got != -1 { + t.Errorf("MaxSubarraySum all negative = %d, want -1", got) + } + + got = MaxSubarraySum([]int{}) + if got != 0 { + t.Errorf("MaxSubarraySum empty = %d, want 0", got) + } +} + +func TestIsPrime(t *testing.T) { + primes := []int{2, 3, 5, 7, 11, 13, 17, 19, 23} + for _, p := range primes { + if !IsPrime(p) { + t.Errorf("IsPrime(%d) = false, want true", p) + } + } + + nonPrimes := []int{0, 1, 4, 6, 8, 9, 10, 15} + for _, n := range nonPrimes { + if IsPrime(n) { + t.Errorf("IsPrime(%d) = true, want false", n) + } + } +} + +func TestPrimesUpTo(t *testing.T) { + got := PrimesUpTo(20) + want := []int{2, 3, 5, 7, 11, 13, 17, 19} + if !reflect.DeepEqual(got, want) { + t.Errorf("PrimesUpTo(20) = %v, want %v", got, want) + } +} + +func TestGCD(t *testing.T) { + tests := []struct { + a, b, want int + }{ + {12, 8, 4}, + {7, 13, 1}, + {0, 5, 5}, + {-12, 8, 4}, + } + + for _, tc := range tests { + got := GCD(tc.a, tc.b) + if got != tc.want { + t.Errorf("GCD(%d, %d) = %d, want %d", tc.a, tc.b, got, tc.want) + } + } +} + +func TestLCM(t *testing.T) { + tests := []struct { + a, b, want int + }{ + {4, 6, 12}, + {7, 13, 91}, + {0, 5, 0}, + } + + for _, tc := range tests { + got := LCM(tc.a, tc.b) + if got != tc.want { + t.Errorf("LCM(%d, %d) = %d, want %d", tc.a, tc.b, got, tc.want) + } + } +} diff --git a/code_to_optimize/go/calculator.go b/code_to_optimize/go/calculator.go new file mode 100644 index 000000000..161537293 --- /dev/null +++ b/code_to_optimize/go/calculator.go @@ -0,0 +1,117 @@ +package sample + +import "math" + +func Factorial(n int) int64 { + if n < 0 { + panic("factorial not defined for negative numbers") + } + if n <= 1 { + return 1 + } + return int64(n) * Factorial(n-1) +} + +func Power(base float64, exp int) float64 { + if exp < 0 { + return 1.0 / Power(base, -exp) + } + if exp == 0 { + return 1 + } + result := 1.0 + for i := 0; i < exp; i++ { + result *= base + } + return result +} + +func SumRange(start, end int) int64 { + var sum int64 + for i := start; i <= end; i++ { + sum += int64(i) + } + return sum +} + +func Average(nums []float64) float64 { + if len(nums) == 0 { + return 0 + } + sum := 0.0 + for _, n := range nums { + sum = sum + n + } + return sum / float64(len(nums)) +} + +func StandardDeviation(nums []float64) float64 { + if len(nums) == 0 { + return 0 + } + avg := Average(nums) + sumSqDiff := 0.0 + for _, n := range nums { + diff := n - avg + sumSqDiff = sumSqDiff + diff*diff + } + return math.Sqrt(sumSqDiff / float64(len(nums))) +} + +func Median(nums []float64) float64 { + if len(nums) == 0 { + return 0 + } + + sorted := make([]float64, len(nums)) + copy(sorted, nums) + for i := 0; i < len(sorted); i++ { + for j := i + 1; j < len(sorted); j++ { + if sorted[j] < sorted[i] { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + mid := len(sorted) / 2 + if len(sorted)%2 == 0 { + return (sorted[mid-1] + sorted[mid]) / 2 + } + return sorted[mid] +} + +func NthRoot(x float64, n int) float64 { + if n <= 0 { + return 0 + } + if x < 0 && n%2 == 0 { + return 0 + } + + guess := x / float64(n) + for i := 0; i < 1000; i++ { + powered := Power(guess, n-1) + if powered == 0 { + break + } + guess = guess - (Power(guess, n)-x)/(float64(n)*powered) + } + return guess +} + +func Combinations(n, k int) int64 { + if k < 0 || k > n { + return 0 + } + if k == 0 || k == n { + return 1 + } + return Factorial(n) / (Factorial(k) * Factorial(n-k)) +} + +func Permutations(n, k int) int64 { + if k < 0 || k > n { + return 0 + } + return Factorial(n) / Factorial(n-k) +} diff --git a/code_to_optimize/go/calculator_test.go b/code_to_optimize/go/calculator_test.go new file mode 100644 index 000000000..0331695c1 --- /dev/null +++ b/code_to_optimize/go/calculator_test.go @@ -0,0 +1,149 @@ +package sample + +import ( + "math" + "testing" +) + +func TestFactorial(t *testing.T) { + tests := []struct { + n int + want int64 + }{ + {0, 1}, + {1, 1}, + {5, 120}, + {10, 3628800}, + } + + for _, tc := range tests { + got := Factorial(tc.n) + if got != tc.want { + t.Errorf("Factorial(%d) = %d, want %d", tc.n, got, tc.want) + } + } +} + +func TestFactorialPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for negative input") + } + }() + Factorial(-1) +} + +func TestPower(t *testing.T) { + tests := []struct { + base float64 + exp int + want float64 + }{ + {2, 10, 1024}, + {3, 0, 1}, + {5, 1, 5}, + {2, -1, 0.5}, + } + + for _, tc := range tests { + got := Power(tc.base, tc.exp) + if math.Abs(got-tc.want) > 1e-9 { + t.Errorf("Power(%f, %d) = %f, want %f", tc.base, tc.exp, got, tc.want) + } + } +} + +func TestSumRange(t *testing.T) { + if got := SumRange(1, 100); got != 5050 { + t.Errorf("SumRange(1,100) = %d, want 5050", got) + } + if got := SumRange(5, 5); got != 5 { + t.Errorf("SumRange(5,5) = %d, want 5", got) + } +} + +func TestAverage(t *testing.T) { + got := Average([]float64{1, 2, 3, 4, 5}) + if got != 3.0 { + t.Errorf("Average = %f, want 3.0", got) + } + + got = Average([]float64{}) + if got != 0 { + t.Errorf("Average empty = %f, want 0", got) + } +} + +func TestStandardDeviation(t *testing.T) { + got := StandardDeviation([]float64{2, 4, 4, 4, 5, 5, 7, 9}) + if math.Abs(got-2.0) > 0.01 { + t.Errorf("StandardDeviation = %f, want ~2.0", got) + } +} + +func TestMedian(t *testing.T) { + got := Median([]float64{3, 1, 2}) + if got != 2.0 { + t.Errorf("Median odd = %f, want 2.0", got) + } + + got = Median([]float64{4, 1, 3, 2}) + if got != 2.5 { + t.Errorf("Median even = %f, want 2.5", got) + } + + got = Median([]float64{}) + if got != 0 { + t.Errorf("Median empty = %f, want 0", got) + } +} + +func TestNthRoot(t *testing.T) { + got := NthRoot(27, 3) + if math.Abs(got-3.0) > 1e-6 { + t.Errorf("NthRoot(27,3) = %f, want 3.0", got) + } + + got = NthRoot(16, 4) + if math.Abs(got-2.0) > 1e-6 { + t.Errorf("NthRoot(16,4) = %f, want 2.0", got) + } +} + +func TestCombinations(t *testing.T) { + tests := []struct { + n, k int + want int64 + }{ + {5, 2, 10}, + {10, 3, 120}, + {5, 0, 1}, + {5, 5, 1}, + {3, 5, 0}, + } + + for _, tc := range tests { + got := Combinations(tc.n, tc.k) + if got != tc.want { + t.Errorf("Combinations(%d,%d) = %d, want %d", tc.n, tc.k, got, tc.want) + } + } +} + +func TestPermutations(t *testing.T) { + tests := []struct { + n, k int + want int64 + }{ + {5, 2, 20}, + {5, 0, 1}, + {3, 5, 0}, + } + + for _, tc := range tests { + got := Permutations(tc.n, tc.k) + if got != tc.want { + t.Errorf("Permutations(%d,%d) = %d, want %d", tc.n, tc.k, got, tc.want) + } + } +} diff --git a/code_to_optimize/go/fibonacci.go b/code_to_optimize/go/fibonacci.go new file mode 100644 index 000000000..fa9bb9ad1 --- /dev/null +++ b/code_to_optimize/go/fibonacci.go @@ -0,0 +1,108 @@ +package sample + +import "math" + +func Fibonacci(n int) int64 { + if n < 0 { + panic("fibonacci not defined for negative numbers") + } + if n <= 1 { + return int64(n) + } + return Fibonacci(n-1) + Fibonacci(n-2) +} + +func IsFibonacci(num int64) bool { + if num < 0 { + return false + } + check1 := 5*num*num + 4 + check2 := 5*num*num - 4 + return isPerfectSquare(check1) || isPerfectSquare(check2) +} + +func isPerfectSquare(n int64) bool { + if n < 0 { + return false + } + sqrt := int64(math.Sqrt(float64(n))) + return sqrt*sqrt == n +} + +func FibonacciSequence(n int) []int64 { + if n < 0 { + panic("n must be non-negative") + } + if n == 0 { + return []int64{} + } + + result := make([]int64, n) + for i := 0; i < n; i++ { + result[i] = Fibonacci(i) + } + return result +} + +func FibonacciIndex(fibNum int64) int { + if fibNum < 0 { + return -1 + } + if fibNum == 0 { + return 0 + } + if fibNum == 1 { + return 1 + } + + for index := 2; index <= 50; index++ { + fib := Fibonacci(index) + if fib == fibNum { + return index + } + if fib > fibNum { + return -1 + } + } + return -1 +} + +func SumFibonacci(n int) int64 { + if n <= 0 { + return 0 + } + var sum int64 + for i := 0; i < n; i++ { + sum += Fibonacci(i) + } + return sum +} + +func FibonacciUpTo(limit int64) []int64 { + var result []int64 + if limit <= 0 { + return result + } + + for index := 0; index <= 50; index++ { + fib := Fibonacci(index) + if fib >= limit { + break + } + result = append(result, fib) + } + return result +} + +func AreConsecutiveFibonacci(a, b int64) bool { + if !IsFibonacci(a) || !IsFibonacci(b) { + return false + } + indexA := FibonacciIndex(a) + indexB := FibonacciIndex(b) + diff := indexA - indexB + if diff < 0 { + diff = -diff + } + return diff == 1 +} diff --git a/code_to_optimize/go/fibonacci_test.go b/code_to_optimize/go/fibonacci_test.go new file mode 100644 index 000000000..3189ae754 --- /dev/null +++ b/code_to_optimize/go/fibonacci_test.go @@ -0,0 +1,138 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestFibonacci(t *testing.T) { + tests := []struct { + n int + want int64 + }{ + {0, 0}, + {1, 1}, + {2, 1}, + {5, 5}, + {10, 55}, + {20, 6765}, + } + + for _, tc := range tests { + got := Fibonacci(tc.n) + if got != tc.want { + t.Errorf("Fibonacci(%d) = %d, want %d", tc.n, got, tc.want) + } + } +} + +func TestFibonacciPanicsOnNegative(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for negative input") + } + }() + Fibonacci(-1) +} + +func TestIsFibonacci(t *testing.T) { + fibs := []int64{0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55} + for _, f := range fibs { + if !IsFibonacci(f) { + t.Errorf("IsFibonacci(%d) = false, want true", f) + } + } + + nonFibs := []int64{4, 6, 7, 9, 10, 22} + for _, f := range nonFibs { + if IsFibonacci(f) { + t.Errorf("IsFibonacci(%d) = true, want false", f) + } + } + + if IsFibonacci(-1) { + t.Error("IsFibonacci(-1) should be false") + } +} + +func TestFibonacciSequence(t *testing.T) { + got := FibonacciSequence(7) + want := []int64{0, 1, 1, 2, 3, 5, 8} + if !reflect.DeepEqual(got, want) { + t.Errorf("FibonacciSequence(7) = %v, want %v", got, want) + } + + got = FibonacciSequence(0) + if len(got) != 0 { + t.Errorf("FibonacciSequence(0) should be empty, got %v", got) + } +} + +func TestFibonacciIndex(t *testing.T) { + tests := []struct { + num int64 + want int + }{ + {0, 0}, + {1, 1}, + {5, 5}, + {8, 6}, + {55, 10}, + {4, -1}, + {-1, -1}, + } + + for _, tc := range tests { + got := FibonacciIndex(tc.num) + if got != tc.want { + t.Errorf("FibonacciIndex(%d) = %d, want %d", tc.num, got, tc.want) + } + } +} + +func TestSumFibonacci(t *testing.T) { + tests := []struct { + n int + want int64 + }{ + {0, 0}, + {1, 0}, + {5, 7}, + {7, 20}, + } + + for _, tc := range tests { + got := SumFibonacci(tc.n) + if got != tc.want { + t.Errorf("SumFibonacci(%d) = %d, want %d", tc.n, got, tc.want) + } + } +} + +func TestFibonacciUpTo(t *testing.T) { + got := FibonacciUpTo(10) + want := []int64{0, 1, 1, 2, 3, 5, 8} + if !reflect.DeepEqual(got, want) { + t.Errorf("FibonacciUpTo(10) = %v, want %v", got, want) + } + + got = FibonacciUpTo(0) + if len(got) != 0 { + t.Errorf("FibonacciUpTo(0) should be empty") + } +} + +func TestAreConsecutiveFibonacci(t *testing.T) { + if !AreConsecutiveFibonacci(5, 8) { + t.Error("5 and 8 are consecutive fibonacci numbers") + } + if !AreConsecutiveFibonacci(8, 5) { + t.Error("8 and 5 are consecutive fibonacci numbers") + } + if AreConsecutiveFibonacci(5, 13) { + t.Error("5 and 13 are not consecutive fibonacci numbers") + } + if AreConsecutiveFibonacci(4, 5) { + t.Error("4 is not a fibonacci number") + } +} diff --git a/code_to_optimize/go/go.mod b/code_to_optimize/go/go.mod new file mode 100644 index 000000000..f037eded9 --- /dev/null +++ b/code_to_optimize/go/go.mod @@ -0,0 +1,3 @@ +module example/codeflash-go-sample + +go 1.21 diff --git a/code_to_optimize/go/graph.go b/code_to_optimize/go/graph.go new file mode 100644 index 000000000..d91da788c --- /dev/null +++ b/code_to_optimize/go/graph.go @@ -0,0 +1,197 @@ +package sample + +func BFS(graph map[int][]int, start int) []int { + visited := make(map[int]bool) + var result []int + queue := []int{start} + visited[start] = true + + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + result = append(result, node) + + neighbors := graph[node] + for i := 0; i < len(neighbors); i++ { + for j := i + 1; j < len(neighbors); j++ { + if neighbors[j] < neighbors[i] { + neighbors[i], neighbors[j] = neighbors[j], neighbors[i] + } + } + } + + for _, neighbor := range neighbors { + if !visited[neighbor] { + visited[neighbor] = true + queue = append(queue, neighbor) + } + } + } + return result +} + +func DFS(graph map[int][]int, start int) []int { + visited := make(map[int]bool) + var result []int + dfsHelper(graph, start, visited, &result) + return result +} + +func dfsHelper(graph map[int][]int, node int, visited map[int]bool, result *[]int) { + if visited[node] { + return + } + visited[node] = true + *result = append(*result, node) + + neighbors := make([]int, len(graph[node])) + copy(neighbors, graph[node]) + for i := 0; i < len(neighbors); i++ { + for j := i + 1; j < len(neighbors); j++ { + if neighbors[j] < neighbors[i] { + neighbors[i], neighbors[j] = neighbors[j], neighbors[i] + } + } + } + + for _, neighbor := range neighbors { + dfsHelper(graph, neighbor, visited, result) + } +} + +func ShortestPath(graph map[int][]int, start, end int) int { + if start == end { + return 0 + } + + visited := make(map[int]bool) + type entry struct { + node int + dist int + } + queue := []entry{{start, 0}} + visited[start] = true + + for len(queue) > 0 { + curr := queue[0] + queue = queue[1:] + + for _, neighbor := range graph[curr.node] { + if neighbor == end { + return curr.dist + 1 + } + if !visited[neighbor] { + visited[neighbor] = true + queue = append(queue, entry{neighbor, curr.dist + 1}) + } + } + } + return -1 +} + +func HasCycle(graph map[int][]int) bool { + visited := make(map[int]bool) + recStack := make(map[int]bool) + + for node := range graph { + if hasCycleDFS(graph, node, visited, recStack) { + return true + } + } + return false +} + +func hasCycleDFS(graph map[int][]int, node int, visited, recStack map[int]bool) bool { + if recStack[node] { + return true + } + if visited[node] { + return false + } + + visited[node] = true + recStack[node] = true + + for _, neighbor := range graph[node] { + if hasCycleDFS(graph, neighbor, visited, recStack) { + return true + } + } + + recStack[node] = false + return false +} + +func TopologicalSort(graph map[int][]int) []int { + inDegree := make(map[int]int) + for node := range graph { + if _, ok := inDegree[node]; !ok { + inDegree[node] = 0 + } + for _, neighbor := range graph[node] { + inDegree[neighbor]++ + } + } + + var queue []int + for node, degree := range inDegree { + if degree == 0 { + queue = append(queue, node) + } + } + + for i := 0; i < len(queue); i++ { + for j := i + 1; j < len(queue); j++ { + if queue[j] < queue[i] { + queue[i], queue[j] = queue[j], queue[i] + } + } + } + + var result []int + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + result = append(result, node) + + for _, neighbor := range graph[node] { + inDegree[neighbor]-- + if inDegree[neighbor] == 0 { + queue = append(queue, neighbor) + for i := 0; i < len(queue); i++ { + for j := i + 1; j < len(queue); j++ { + if queue[j] < queue[i] { + queue[i], queue[j] = queue[j], queue[i] + } + } + } + } + } + } + return result +} + +func ConnectedComponents(graph map[int][]int) [][]int { + visited := make(map[int]bool) + var components [][]int + + for node := range graph { + if !visited[node] { + var component []int + componentDFS(graph, node, visited, &component) + components = append(components, component) + } + } + return components +} + +func componentDFS(graph map[int][]int, node int, visited map[int]bool, component *[]int) { + if visited[node] { + return + } + visited[node] = true + *component = append(*component, node) + for _, neighbor := range graph[node] { + componentDFS(graph, neighbor, visited, component) + } +} diff --git a/code_to_optimize/go/graph_test.go b/code_to_optimize/go/graph_test.go new file mode 100644 index 000000000..d33a38fe0 --- /dev/null +++ b/code_to_optimize/go/graph_test.go @@ -0,0 +1,109 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestBFS(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + got := BFS(graph, 0) + want := []int{0, 1, 2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("BFS = %v, want %v", got, want) + } +} + +func TestBFSSingleNode(t *testing.T) { + graph := map[int][]int{0: {}} + got := BFS(graph, 0) + want := []int{0} + if !reflect.DeepEqual(got, want) { + t.Errorf("BFS single = %v, want %v", got, want) + } +} + +func TestDFS(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + got := DFS(graph, 0) + want := []int{0, 1, 3, 2} + if !reflect.DeepEqual(got, want) { + t.Errorf("DFS = %v, want %v", got, want) + } +} + +func TestShortestPath(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + + if got := ShortestPath(graph, 0, 3); got != 2 { + t.Errorf("ShortestPath(0,3) = %d, want 2", got) + } + if got := ShortestPath(graph, 0, 0); got != 0 { + t.Errorf("ShortestPath(0,0) = %d, want 0", got) + } + if got := ShortestPath(graph, 3, 0); got != -1 { + t.Errorf("ShortestPath(3,0) = %d, want -1", got) + } +} + +func TestHasCycle(t *testing.T) { + acyclic := map[int][]int{ + 0: {1}, + 1: {2}, + 2: {}, + } + if HasCycle(acyclic) { + t.Error("expected no cycle in DAG") + } + + cyclic := map[int][]int{ + 0: {1}, + 1: {2}, + 2: {0}, + } + if !HasCycle(cyclic) { + t.Error("expected cycle") + } +} + +func TestTopologicalSort(t *testing.T) { + graph := map[int][]int{ + 0: {1, 2}, + 1: {3}, + 2: {3}, + 3: {}, + } + got := TopologicalSort(graph) + want := []int{0, 1, 2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("TopologicalSort = %v, want %v", got, want) + } +} + +func TestConnectedComponents(t *testing.T) { + graph := map[int][]int{ + 0: {1}, + 1: {0}, + 2: {3}, + 3: {2}, + } + components := ConnectedComponents(graph) + if len(components) != 2 { + t.Errorf("expected 2 components, got %d", len(components)) + } +} diff --git a/code_to_optimize/go/matrix.go b/code_to_optimize/go/matrix.go new file mode 100644 index 000000000..0dd7e1179 --- /dev/null +++ b/code_to_optimize/go/matrix.go @@ -0,0 +1,122 @@ +package sample + +import "math" + +func MatrixMultiply(a, b [][]float64) [][]float64 { + if len(a) == 0 || len(b) == 0 { + return nil + } + + rows := len(a) + cols := len(b[0]) + inner := len(b) + + result := make([][]float64, rows) + for i := range result { + result[i] = make([]float64, cols) + } + + for i := 0; i < rows; i++ { + for j := 0; j < cols; j++ { + sum := 0.0 + for k := 0; k < inner; k++ { + sum = sum + a[i][k]*b[k][j] + } + result[i][j] = sum + } + } + return result +} + +func MatrixTranspose(m [][]float64) [][]float64 { + if len(m) == 0 { + return nil + } + + rows := len(m) + cols := len(m[0]) + + result := make([][]float64, cols) + for i := range result { + result[i] = make([]float64, rows) + } + + for i := 0; i < rows; i++ { + for j := 0; j < cols; j++ { + result[j][i] = m[i][j] + } + } + return result +} + +func MatrixAdd(a, b [][]float64) [][]float64 { + if len(a) == 0 || len(b) == 0 { + return nil + } + + rows := len(a) + cols := len(a[0]) + + result := make([][]float64, rows) + for i := range result { + result[i] = make([]float64, cols) + for j := 0; j < cols; j++ { + result[i][j] = a[i][j] + b[i][j] + } + } + return result +} + +func MatrixScale(m [][]float64, scalar float64) [][]float64 { + if len(m) == 0 { + return nil + } + + rows := len(m) + cols := len(m[0]) + + result := make([][]float64, rows) + for i := range result { + result[i] = make([]float64, cols) + for j := 0; j < cols; j++ { + result[i][j] = m[i][j] * scalar + } + } + return result +} + +func DotProduct(a, b []float64) float64 { + sum := 0.0 + for i := 0; i < len(a); i++ { + sum = sum + a[i]*b[i] + } + return sum +} + +func VectorNorm(v []float64) float64 { + sum := 0.0 + for _, val := range v { + sum = sum + val*val + } + return math.Sqrt(sum) +} + +func CosineSimilarity(a, b []float64) float64 { + dot := DotProduct(a, b) + normA := VectorNorm(a) + normB := VectorNorm(b) + if normA == 0 || normB == 0 { + return 0 + } + return dot / (normA * normB) +} + +func FlattenMatrix(m [][]float64) []float64 { + var result []float64 + for _, row := range m { + for _, val := range row { + result = append(result, val) + } + } + return result +} diff --git a/code_to_optimize/go/matrix_test.go b/code_to_optimize/go/matrix_test.go new file mode 100644 index 000000000..90471c5fa --- /dev/null +++ b/code_to_optimize/go/matrix_test.go @@ -0,0 +1,112 @@ +package sample + +import ( + "math" + "reflect" + "testing" +) + +func TestMatrixMultiply(t *testing.T) { + a := [][]float64{{1, 2}, {3, 4}} + b := [][]float64{{5, 6}, {7, 8}} + got := MatrixMultiply(a, b) + want := [][]float64{{19, 22}, {43, 50}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixMultiply = %v, want %v", got, want) + } +} + +func TestMatrixMultiplyEmpty(t *testing.T) { + got := MatrixMultiply([][]float64{}, [][]float64{{1}}) + if got != nil { + t.Errorf("expected nil for empty input, got %v", got) + } +} + +func TestMatrixMultiplyIdentity(t *testing.T) { + a := [][]float64{{1, 2, 3}, {4, 5, 6}} + identity := [][]float64{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}} + got := MatrixMultiply(a, identity) + if !reflect.DeepEqual(got, a) { + t.Errorf("A * I = %v, want %v", got, a) + } +} + +func TestMatrixTranspose(t *testing.T) { + m := [][]float64{{1, 2, 3}, {4, 5, 6}} + got := MatrixTranspose(m) + want := [][]float64{{1, 4}, {2, 5}, {3, 6}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixTranspose = %v, want %v", got, want) + } +} + +func TestMatrixTransposeEmpty(t *testing.T) { + got := MatrixTranspose([][]float64{}) + if got != nil { + t.Errorf("expected nil for empty input") + } +} + +func TestMatrixAdd(t *testing.T) { + a := [][]float64{{1, 2}, {3, 4}} + b := [][]float64{{5, 6}, {7, 8}} + got := MatrixAdd(a, b) + want := [][]float64{{6, 8}, {10, 12}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixAdd = %v, want %v", got, want) + } +} + +func TestMatrixScale(t *testing.T) { + m := [][]float64{{1, 2}, {3, 4}} + got := MatrixScale(m, 2.0) + want := [][]float64{{2, 4}, {6, 8}} + if !reflect.DeepEqual(got, want) { + t.Errorf("MatrixScale = %v, want %v", got, want) + } +} + +func TestDotProduct(t *testing.T) { + got := DotProduct([]float64{1, 2, 3}, []float64{4, 5, 6}) + want := 32.0 + if got != want { + t.Errorf("DotProduct = %f, want %f", got, want) + } +} + +func TestVectorNorm(t *testing.T) { + got := VectorNorm([]float64{3, 4}) + want := 5.0 + if got != want { + t.Errorf("VectorNorm = %f, want %f", got, want) + } +} + +func TestCosineSimilarity(t *testing.T) { + a := []float64{1, 0} + b := []float64{0, 1} + got := CosineSimilarity(a, b) + if math.Abs(got) > 1e-9 { + t.Errorf("orthogonal vectors should have cosine similarity 0, got %f", got) + } + + got = CosineSimilarity([]float64{1, 2, 3}, []float64{1, 2, 3}) + if math.Abs(got-1.0) > 1e-9 { + t.Errorf("identical vectors should have cosine similarity 1, got %f", got) + } + + got = CosineSimilarity([]float64{0, 0}, []float64{1, 2}) + if got != 0 { + t.Errorf("zero vector should give 0, got %f", got) + } +} + +func TestFlattenMatrix(t *testing.T) { + m := [][]float64{{1, 2}, {3, 4}, {5, 6}} + got := FlattenMatrix(m) + want := []float64{1, 2, 3, 4, 5, 6} + if !reflect.DeepEqual(got, want) { + t.Errorf("FlattenMatrix = %v, want %v", got, want) + } +} diff --git a/code_to_optimize/go/sorting.go b/code_to_optimize/go/sorting.go new file mode 100644 index 000000000..7de2a322e --- /dev/null +++ b/code_to_optimize/go/sorting.go @@ -0,0 +1,94 @@ +package sample + +func BubbleSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n; i++ { + for j := 0; j < n-1; j++ { + if result[j] > result[j+1] { + temp := result[j] + result[j] = result[j+1] + result[j+1] = temp + } + } + } + return result +} + +func BubbleSortDescending(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if result[j] < result[j+1] { + temp := result[j] + result[j] = result[j+1] + result[j+1] = temp + } + } + } + return result +} + +func InsertionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 1; i < n; i++ { + key := result[i] + j := i - 1 + for j >= 0 && result[j] > key { + result[j+1] = result[j] + j-- + } + result[j+1] = key + } + return result +} + +func SelectionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + minIdx := i + for j := i + 1; j < n; j++ { + if result[j] < result[minIdx] { + minIdx = j + } + } + result[minIdx], result[i] = result[i], result[minIdx] + } + return result +} + +func IsSorted(arr []int) bool { + for i := 0; i < len(arr)-1; i++ { + if arr[i] > arr[i+1] { + return false + } + } + return true +} diff --git a/code_to_optimize/go/sorting_test.go b/code_to_optimize/go/sorting_test.go new file mode 100644 index 000000000..ac6890693 --- /dev/null +++ b/code_to_optimize/go/sorting_test.go @@ -0,0 +1,122 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestBubbleSort(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{5, 3, 1, 4, 2}, []int{1, 2, 3, 4, 5}}, + {[]int{3, 2, 1}, []int{1, 2, 3}}, + {[]int{1}, []int{1}}, + {[]int{}, []int{}}, + {[]int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5}}, + } + + for _, tc := range tests { + result := BubbleSort(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("BubbleSort(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestBubbleSortWithDuplicates(t *testing.T) { + result := BubbleSort([]int{3, 2, 4, 1, 3, 2}) + expected := []int{1, 2, 2, 3, 3, 4} + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v, want %v", result, expected) + } +} + +func TestBubbleSortWithNegatives(t *testing.T) { + result := BubbleSort([]int{3, -2, 7, 0, -5}) + expected := []int{-5, -2, 0, 3, 7} + if !reflect.DeepEqual(result, expected) { + t.Errorf("got %v, want %v", result, expected) + } +} + +func TestBubbleSortDescending(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{1, 3, 5, 2, 4}, []int{5, 4, 3, 2, 1}}, + {[]int{1, 2, 3}, []int{3, 2, 1}}, + {[]int{}, []int{}}, + } + + for _, tc := range tests { + result := BubbleSortDescending(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("BubbleSortDescending(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestInsertionSort(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{5, 3, 1, 4, 2}, []int{1, 2, 3, 4, 5}}, + {[]int{3, 2, 1}, []int{1, 2, 3}}, + {[]int{1}, []int{1}}, + {[]int{}, []int{}}, + } + + for _, tc := range tests { + result := InsertionSort(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("InsertionSort(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestSelectionSort(t *testing.T) { + tests := []struct { + input []int + expected []int + }{ + {[]int{5, 3, 1, 4, 2}, []int{1, 2, 3, 4, 5}}, + {[]int{3, 2, 1}, []int{1, 2, 3}}, + {[]int{1}, []int{1}}, + } + + for _, tc := range tests { + result := SelectionSort(tc.input) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("SelectionSort(%v) = %v, want %v", tc.input, result, tc.expected) + } + } +} + +func TestIsSorted(t *testing.T) { + if !IsSorted([]int{1, 2, 3, 4, 5}) { + t.Error("expected sorted") + } + if !IsSorted([]int{1}) { + t.Error("expected sorted") + } + if !IsSorted([]int{}) { + t.Error("expected sorted") + } + if IsSorted([]int{5, 3, 1}) { + t.Error("expected not sorted") + } +} + +func TestBubbleSortDoesNotMutateInput(t *testing.T) { + original := []int{5, 3, 1, 4, 2} + saved := make([]int, len(original)) + copy(saved, original) + BubbleSort(original) + if !reflect.DeepEqual(original, saved) { + t.Errorf("input was mutated: got %v, want %v", original, saved) + } +} diff --git a/code_to_optimize/go/stringutils.go b/code_to_optimize/go/stringutils.go new file mode 100644 index 000000000..5ee8166de --- /dev/null +++ b/code_to_optimize/go/stringutils.go @@ -0,0 +1,125 @@ +package sample + +import "strings" + +func ReverseString(s string) string { + result := "" + for i := len(s) - 1; i >= 0; i-- { + result = result + string(s[i]) + } + return result +} + +func IsPalindrome(s string) bool { + reversed := ReverseString(s) + return s == reversed +} + +func CountWords(s string) int { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return 0 + } + return len(strings.Fields(trimmed)) +} + +func CapitalizeWords(s string) string { + if s == "" { + return s + } + + words := strings.Split(s, " ") + result := "" + + for i, word := range words { + if len(word) > 0 { + capitalized := strings.ToUpper(word[:1]) + strings.ToLower(word[1:]) + result = result + capitalized + } + if i < len(words)-1 { + result = result + " " + } + } + return result +} + +func CountOccurrences(s, sub string) int { + if sub == "" { + return 0 + } + + count := 0 + index := 0 + for { + pos := strings.Index(s[index:], sub) + if pos == -1 { + break + } + count++ + index = index + pos + 1 + } + return count +} + +func RemoveWhitespace(s string) string { + result := "" + for _, c := range s { + if c != ' ' && c != '\t' && c != '\n' && c != '\r' { + result = result + string(c) + } + } + return result +} + +func FindAllIndices(s string, c byte) []int { + var indices []int + for i := 0; i < len(s); i++ { + if s[i] == c { + indices = append(indices, i) + } + } + return indices +} + +func IsNumeric(s string) bool { + if s == "" { + return false + } + for _, c := range s { + if c < '0' || c > '9' { + return false + } + } + return true +} + +func Repeat(s string, n int) string { + if n <= 0 { + return "" + } + result := "" + for i := 0; i < n; i++ { + result = result + s + } + return result +} + +func Truncate(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } + if len(s) <= maxLen { + return s + } + if maxLen <= 3 { + return s[:maxLen] + } + return s[:maxLen-3] + "..." +} + +func ToTitleCase(s string) string { + if s == "" { + return s + } + return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) +} diff --git a/code_to_optimize/go/stringutils_test.go b/code_to_optimize/go/stringutils_test.go new file mode 100644 index 000000000..025928c2c --- /dev/null +++ b/code_to_optimize/go/stringutils_test.go @@ -0,0 +1,216 @@ +package sample + +import ( + "reflect" + "testing" +) + +func TestReverseString(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello", "olleh"}, + {"a", "a"}, + {"", ""}, + {"abcd", "dcba"}, + } + + for _, tc := range tests { + got := ReverseString(tc.input) + if got != tc.want { + t.Errorf("ReverseString(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestIsPalindrome(t *testing.T) { + palindromes := []string{"racecar", "madam", "a", "", "abba"} + for _, s := range palindromes { + if !IsPalindrome(s) { + t.Errorf("IsPalindrome(%q) = false, want true", s) + } + } + + nonPalindromes := []string{"hello", "ab"} + for _, s := range nonPalindromes { + if IsPalindrome(s) { + t.Errorf("IsPalindrome(%q) = true, want false", s) + } + } +} + +func TestCountWords(t *testing.T) { + tests := []struct { + input string + want int + }{ + {"hello world test", 3}, + {"hello", 1}, + {"", 0}, + {" ", 0}, + {" multiple spaces between words ", 4}, + } + + for _, tc := range tests { + got := CountWords(tc.input) + if got != tc.want { + t.Errorf("CountWords(%q) = %d, want %d", tc.input, got, tc.want) + } + } +} + +func TestCapitalizeWords(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello world", "Hello World"}, + {"HELLO", "Hello"}, + {"", ""}, + {"one two three", "One Two Three"}, + } + + for _, tc := range tests { + got := CapitalizeWords(tc.input) + if got != tc.want { + t.Errorf("CapitalizeWords(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestCountOccurrences(t *testing.T) { + tests := []struct { + s, sub string + want int + }{ + {"hello hello", "hello", 2}, + {"aaa", "a", 3}, + {"aaa", "aa", 2}, + {"hello", "world", 0}, + {"hello", "", 0}, + } + + for _, tc := range tests { + got := CountOccurrences(tc.s, tc.sub) + if got != tc.want { + t.Errorf("CountOccurrences(%q, %q) = %d, want %d", tc.s, tc.sub, got, tc.want) + } + } +} + +func TestRemoveWhitespace(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello world", "helloworld"}, + {" a b c ", "abc"}, + {"test", "test"}, + {" ", ""}, + {"", ""}, + } + + for _, tc := range tests { + got := RemoveWhitespace(tc.input) + if got != tc.want { + t.Errorf("RemoveWhitespace(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestFindAllIndices(t *testing.T) { + got := FindAllIndices("hello", 'l') + want := []int{2, 3} + if !reflect.DeepEqual(got, want) { + t.Errorf("FindAllIndices(\"hello\", 'l') = %v, want %v", got, want) + } + + got = FindAllIndices("aaa", 'a') + if len(got) != 3 { + t.Errorf("expected 3 indices, got %d", len(got)) + } + + got = FindAllIndices("hello", 'z') + if len(got) != 0 { + t.Errorf("expected 0 indices, got %d", len(got)) + } + + got = FindAllIndices("", 'a') + if len(got) != 0 { + t.Errorf("expected 0 indices, got %d", len(got)) + } +} + +func TestIsNumeric(t *testing.T) { + numerics := []string{"12345", "0", "007"} + for _, s := range numerics { + if !IsNumeric(s) { + t.Errorf("IsNumeric(%q) = false, want true", s) + } + } + + nonNumerics := []string{"12.34", "-123", "abc", "12a34", ""} + for _, s := range nonNumerics { + if IsNumeric(s) { + t.Errorf("IsNumeric(%q) = true, want false", s) + } + } +} + +func TestRepeat(t *testing.T) { + tests := []struct { + s string + n int + want string + }{ + {"abc", 3, "abcabcabc"}, + {"a", 3, "aaa"}, + {"abc", 0, ""}, + {"abc", -1, ""}, + } + + for _, tc := range tests { + got := Repeat(tc.s, tc.n) + if got != tc.want { + t.Errorf("Repeat(%q, %d) = %q, want %q", tc.s, tc.n, got, tc.want) + } + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + s string + maxLen int + want string + }{ + {"hello", 10, "hello"}, + {"hello world", 6, "hel..."}, + {"hello world", 8, "hello..."}, + {"hello", 0, ""}, + {"hello", 3, "hel"}, + } + + for _, tc := range tests { + got := Truncate(tc.s, tc.maxLen) + if got != tc.want { + t.Errorf("Truncate(%q, %d) = %q, want %q", tc.s, tc.maxLen, got, tc.want) + } + } +} + +func TestToTitleCase(t *testing.T) { + tests := []struct { + input, want string + }{ + {"hello", "Hello"}, + {"HELLO", "Hello"}, + {"hELLO", "Hello"}, + {"a", "A"}, + {"", ""}, + } + + for _, tc := range tests { + got := ToTitleCase(tc.input) + if got != tc.want { + t.Errorf("ToTitleCase(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 3127649f2..ec2960a97 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -49,7 +49,7 @@ def __init__(self) -> None: self.is_local = self.base_url == "http://localhost:8000" # (connect_timeout, read_timeout) — connect should be fast; read # can be slow because the server runs LLM inference. - self.timeout: float | tuple[float, float] | None = (10, 300) + self.timeout: float | tuple[float, float] | None = (10, 600) def get_next_sequence(self) -> int: """Get the next LLM call sequence number.""" @@ -113,6 +113,8 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) + print(f"url: {url}") + print(f"payload: {json_payload}") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 2db13efe8..b39d9567c 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -116,13 +116,16 @@ def process_pyproject_config(args: Namespace) -> Namespace: # Default to module_root if not specified is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript") is_java_project = pyproject_config.get("language") == "java" + is_go_project = pyproject_config.get("language") == "go" # Set the test framework singleton for JS/TS projects if is_js_ts_project and pyproject_config.get("test_framework"): set_current_test_framework(pyproject_config["test_framework"]) if args.tests_root is None: - if is_java_project: + if is_go_project: + args.tests_root = args.module_root + elif is_java_project: # Try standard Maven/Gradle test directories for test_dir in ["src/test/java", "test", "tests"]: test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir) @@ -202,7 +205,10 @@ def process_pyproject_config(args: Namespace) -> Namespace: args.benchmarks_root = Path(args.benchmarks_root).resolve() args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path) - if is_java_project and pyproject_file_path.is_dir(): + if is_go_project and pyproject_file_path.is_dir(): + args.project_root = pyproject_file_path.resolve() + args.test_project_root = pyproject_file_path.resolve() + elif is_java_project and pyproject_file_path.is_dir(): # For Java projects, pyproject_file_path IS the project root directory (not a file). # Override project_root which may have resolved to a sub-module. args.project_root = pyproject_file_path.resolve() diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 196779589..87960fa3f 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -12,6 +12,28 @@ ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {} +def _try_parse_go_config() -> tuple[dict[str, Any], Path] | None: + dir_path = Path.cwd() + while dir_path != dir_path.parent: + if (dir_path / "go.mod").exists(): + module_root = str(dir_path.resolve()) + return { + "language": "go", + "module_root": module_root, + "tests_root": module_root, + "pytest_cmd": "pytest", + "git_remote": "origin", + "disable_telemetry": False, + "disable_imports_sorting": False, + "override_fixtures": False, + "benchmark": False, + "formatter_cmds": [], + "ignore_paths": [], + }, dir_path + dir_path = dir_path.parent + return None + + def _try_parse_java_build_config() -> tuple[dict[str, Any], Path] | None: """Detect Java project from build files and parse config from pom.xml/gradle.properties. @@ -106,11 +128,23 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]: def parse_config_file( config_file_path: Path | None = None, override_formatter_check: bool = False ) -> tuple[dict[str, Any], Path]: - # Detect all config sources — Java build files, package.json, pyproject.toml + # Detect all config sources — Go modules, Java build files, package.json, pyproject.toml + go_result = _try_parse_go_config() if config_file_path is None else None java_result = _try_parse_java_build_config() if config_file_path is None else None package_json_path = find_package_json(config_file_path) pyproject_toml_path = find_closest_config_file("pyproject.toml") if config_file_path is None else None + # Use Go config only if no closer config exists + if go_result is not None: + go_depth = len(go_result[1].parts) + has_closer = ( + (java_result is not None and len(java_result[1].parts) >= go_depth) + or (package_json_path is not None and len(package_json_path.parent.parts) >= go_depth) + or (pyproject_toml_path is not None and len(pyproject_toml_path.parent.parts) >= go_depth) + ) + if not has_closer: + return go_result + # Use Java config only if no closer JS/Python config exists (monorepo support). # In a monorepo with a parent pom.xml and a child package.json, the closer config wins. if java_result is not None: diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index fdac43c25..e311e2190 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -883,6 +883,8 @@ def is_test_file(file_path_normalized: str) -> bool: site_packages_removed_count += len(_functions) continue if not file_path_normalized.startswith(module_root_str + os.sep): + print(f"module_root_str: {module_root_str}") + print(f"file_path_normalized: {file_path_normalized}") non_modules_removed_count += len(_functions) continue diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 71ad03b18..b18e0c60e 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -1712,7 +1712,9 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio logger.debug(f"Failed to instrument test file {test_file} for performance testing") continue - # For JS/TS, preserve .test.ts or .spec.ts suffix for Jest pattern matching + # Preserve language-specific test file naming conventions: + # JS/TS: .test.ts / .spec.ts for Jest pattern matching + # Go: _test.go required by `go test` def get_instrumented_path(original_path: str, suffix: str) -> Path: path_obj = Path(original_path) stem = path_obj.stem @@ -1724,6 +1726,9 @@ def get_instrumented_path(original_path: str, suffix: str) -> Path: elif ".spec" in stem: base, _ = stem.rsplit(".spec", 1) new_stem = f"{base}{suffix}.spec" + elif stem.endswith("_test") and ext == ".go": + base = stem.removesuffix("_test") + new_stem = f"{base}{suffix}_test" else: new_stem = f"{stem}{suffix}" diff --git a/codeflash/languages/golang/context.py b/codeflash/languages/golang/context.py index 3808a0d1e..a2a608e2b 100644 --- a/codeflash/languages/golang/context.py +++ b/codeflash/languages/golang/context.py @@ -63,16 +63,44 @@ def find_helper_functions( ) -> list[HelperFunction]: analyzer = analyzer or GoAnalyzer() target_name = function.function_name + receiver_type = _get_receiver_type(function) - functions = analyzer.find_functions(source) - methods = analyzer.find_methods(source) + all_functions = analyzer.find_functions(source) + all_methods = analyzer.find_methods(source) + + candidate_names: set[str] = set() + for func in all_functions: + if func.name not in ("init", "main") and func.name != target_name: + candidate_names.add(func.name) + for method in all_methods: + if not (method.name == target_name and method.receiver_name == receiver_type): + candidate_names.add(method.name) + + referenced = analyzer.collect_body_identifiers(source, target_name, receiver_type=receiver_type) + needed = referenced & candidate_names + + seen: set[str] = set() + queue = list(needed) + while queue: + name = queue.pop() + if name in seen: + continue + seen.add(name) + ids = analyzer.collect_body_identifiers(source, name) + if not ids: + for method in all_methods: + if method.name == name: + ids = analyzer.collect_body_identifiers(source, name, receiver_type=method.receiver_name) + if ids: + break + for transitive in ids & candidate_names: + if transitive not in seen: + queue.append(transitive) helpers: list[HelperFunction] = [] - for func in functions: - if func.name == target_name: - continue - if func.name in ("init", "main"): + for func in all_functions: + if func.name not in seen: continue extracted = analyzer.extract_function_source(source, func.name) if extracted is None: @@ -88,9 +116,8 @@ def find_helper_functions( ) ) - receiver_type = _get_receiver_type(function) - for method in methods: - if method.name == target_name and method.receiver_name == receiver_type: + for method in all_methods: + if method.name not in seen: continue extracted = analyzer.extract_function_source(source, method.name, receiver_type=method.receiver_name) if extracted is None: diff --git a/codeflash/languages/golang/parse.py b/codeflash/languages/golang/parse.py new file mode 100644 index 000000000..40ed18f91 --- /dev/null +++ b/codeflash/languages/golang/parse.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import json +import logging +import re +from typing import TYPE_CHECKING + +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults + +if TYPE_CHECKING: + import subprocess + from pathlib import Path + + from codeflash.models.models import TestFiles + from codeflash.verification.verification_utils import TestConfig + +logger = logging.getLogger(__name__) + +BENCHMARK_RE = re.compile( + r"^(Benchmark\w+)(?:-\d+)?\s+" + r"(\d+)\s+" + r"([\d.]+)\s+ns/op" + r"(?:\s+(\d+)\s+B/op)?" + r"(?:\s+(\d+)\s+allocs/op)?" +) + + +def parse_go_test_output( + test_json_path: Path, + test_files: TestFiles, + test_config: TestConfig, + run_result: subprocess.CompletedProcess | None = None, +) -> TestResults: + test_results = TestResults() + + content = _read_json_output(test_json_path, run_result) + if not content: + logger.warning("No Go test output to parse from %s", test_json_path) + return test_results + + events = _parse_json_lines(content) + if not events: + logger.warning("No valid JSON events found in %s", test_json_path) + return test_results + + test_states: dict[str, _TestState] = {} + benchmark_results: dict[str, _BenchmarkResult] = {} + + for event in events: + action = event.get("Action") + test_name = event.get("Test") + package = event.get("Package", "") + + if test_name is None: + continue + + if test_name not in test_states: + test_states[test_name] = _TestState(package=package) + + state = test_states[test_name] + + if action == "output": + output_text = event.get("Output", "") + state.stdout += output_text + bench_match = BENCHMARK_RE.search(output_text) + if bench_match: + bench_name = bench_match.group(1) + iterations = int(bench_match.group(2)) + ns_per_op = float(bench_match.group(3)) + b_per_op = int(bench_match.group(4)) if bench_match.group(4) else None + allocs_per_op = int(bench_match.group(5)) if bench_match.group(5) else None + benchmark_results[bench_name] = _BenchmarkResult( + ns_per_op=ns_per_op, iterations=iterations, b_per_op=b_per_op, allocs_per_op=allocs_per_op + ) + elif action == "pass": + state.passed = True + elapsed = event.get("Elapsed", 0) + state.runtime_ns = int(elapsed * 1_000_000_000) if elapsed else None + elif action == "fail": + state.passed = False + elapsed = event.get("Elapsed", 0) + state.runtime_ns = int(elapsed * 1_000_000_000) if elapsed else None + + base_dir = test_config.tests_project_rootdir + + for test_name, state in test_states.items(): + if state.passed is None: + continue + + test_file_path = _resolve_test_file(test_name, state.package, test_files, base_dir) + test_type = _resolve_test_type(test_file_path, test_files) + if test_type is None: + logger.debug("Skipping test %s: could not resolve test type", test_name) + continue + + runtime_ns = state.runtime_ns + bench = benchmark_results.get(test_name) + if bench is not None: + runtime_ns = int(bench.ns_per_op) + + test_results.add( + FunctionTestInvocation( + loop_index=1, + id=InvocationId( + test_module_path=state.package, + test_class_name=None, + test_function_name=test_name, + function_getting_tested="", + iteration_id="", + ), + file_name=test_file_path, + runtime=runtime_ns, + test_framework="go-test", + did_pass=state.passed, + test_type=test_type, + return_value=None, + timed_out=False, + stdout=state.stdout, + ) + ) + + if not test_results: + logger.info("No Go test results parsed from %s", test_json_path) + if run_result is not None: + logger.debug("stdout: %s\nstderr: %s", run_result.stdout, run_result.stderr) + + return test_results + + +class _TestState: + __slots__ = ("package", "passed", "runtime_ns", "stdout") + + def __init__(self, package: str) -> None: + self.package = package + self.passed: bool | None = None + self.runtime_ns: int | None = None + self.stdout: str = "" + + +class _BenchmarkResult: + __slots__ = ("allocs_per_op", "b_per_op", "iterations", "ns_per_op") + + def __init__( + self, ns_per_op: float, iterations: int, b_per_op: int | None = None, allocs_per_op: int | None = None + ) -> None: + self.ns_per_op = ns_per_op + self.iterations = iterations + self.b_per_op = b_per_op + self.allocs_per_op = allocs_per_op + + +def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str: + try: + content = path.read_text(encoding="utf-8") + if content.strip(): + return content + except Exception: + pass + if run_result is not None: + stdout = run_result.stdout + if isinstance(stdout, bytes): + stdout = stdout.decode("utf-8", errors="replace") + return stdout or "" + return "" + + +def _parse_json_lines(content: str) -> list[dict]: + events: list[dict] = [] + for line in content.splitlines(): + line = line.strip() + if not line: + continue + try: + events.append(json.loads(line)) + except json.JSONDecodeError: + continue + return events + + +def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base_dir: Path) -> Path: + + for tf in test_files.test_files: + behavior_path = tf.instrumented_behavior_file_path + if behavior_path.exists(): + return behavior_path + if tf.original_file_path and tf.original_file_path.exists(): + return tf.original_file_path + + if package: + return base_dir / package.replace("/", "_") + return base_dir / f"{test_name}.go" + + +def _resolve_test_type(test_file_path: Path, test_files: TestFiles): + from codeflash.models.test_type import TestType + + test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) + if test_type is not None: + return test_type + test_type = test_files.get_test_type_by_original_file_path(test_file_path) + if test_type is not None: + return test_type + if test_files.test_files: + return test_files.test_files[0].test_type + return TestType.GENERATED_REGRESSION diff --git a/codeflash/languages/golang/parser.py b/codeflash/languages/golang/parser.py index e2d43ffc2..caa3dbb44 100644 --- a/codeflash/languages/golang/parser.py +++ b/codeflash/languages/golang/parser.py @@ -388,7 +388,7 @@ def _collect_identifiers(node: Node | None) -> set[str]: stack = [node] while stack: n = stack.pop() - if n.type in ("identifier", "type_identifier"): + if n.type in ("identifier", "type_identifier", "field_identifier"): text = n.parent if text is not None and text.type not in ("parameter_declaration", "short_var_declaration"): ids.add(n.text.decode("utf-8") if n.text else "") diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index ec68c7d08..e6c0e78dc 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -29,10 +29,12 @@ DependencyResolver, FunctionFilterCriteria, HelperFunction, + InvocationId, ReferenceInfo, TestInfo, ) from codeflash.models.function_types import FunctionToOptimize + from codeflash.models.models import GeneratedTestsList logger = logging.getLogger(__name__) @@ -102,6 +104,13 @@ def discover_tests( def validate_syntax(self, source: str, file_path: Path | None = None) -> bool: return self._analyzer.validate_syntax(source) + def parse_test_xml( + self, test_xml_file_path: Path, test_files: Any, test_config: Any, run_result: Any = None + ) -> Any: + from codeflash.languages.golang.parse import parse_go_test_output + + return parse_go_test_output(test_xml_file_path, test_files, test_config, run_result) + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: return _extract_context(function, project_root, module_root, self._analyzer) @@ -129,19 +138,29 @@ def normalize_code(self, source: str) -> str: def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str: return _add_globals(optimized_code, original_source, self._analyzer) + def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str: + return str(source_file) + def prepare_module( self, module_code: str, module_path: Path, project_root: Path ) -> tuple[dict[Path, Any], None] | None: + from codeflash.models.models import ValidCode + if not self._analyzer.validate_syntax(module_code): return None - return {module_path: module_code}, None + validated: dict[Path, ValidCode] = { + module_path: ValidCode(source_code=module_code, normalized_code=normalize_go_code(module_code)) + } + return validated, None - def setup_test_config(self, test_cfg: Any) -> None: + def setup_test_config(self, test_cfg: Any, file_path: Path, current_worktree: Path | None = None) -> bool: + _ = file_path, current_worktree project_root = getattr(test_cfg, "project_root_path", Path.cwd()) config = detect_go_project(project_root) if config is not None and config.go_version: self._go_version = config.go_version self._go_version_detected = True + return True def detect_module_system(self, project_root: Path, source_file: Path | None = None) -> str | None: return None @@ -182,6 +201,9 @@ def run_benchmarking_tests( inner_iterations, ) + def generate_concolic_tests(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], str]: + return {}, "" + def run_line_profile_tests(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError @@ -200,14 +222,33 @@ def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOpt def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: return test_source - def instrument_existing_test(self, *args: Any, **kwargs: Any) -> tuple[bool, str | None]: - raise NotImplementedError + def instrument_existing_test( + self, test_path: Path, call_positions: Any, function_to_optimize: Any, tests_project_root: Path, mode: str + ) -> tuple[bool, str | None]: + _ = call_positions, function_to_optimize, tests_project_root, mode + try: + return True, test_path.read_text(encoding="utf-8") + except Exception: + return False, None - def postprocess_generated_tests(self, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError + def postprocess_generated_tests( + self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path + ) -> GeneratedTestsList: + _ = test_framework, project_root, source_file_path + return generated_tests - def process_generated_test_strings(self, *args: Any, **kwargs: Any) -> Any: - raise NotImplementedError + def process_generated_test_strings( + self, + generated_test_source: str, + instrumented_behavior_test_source: str, + instrumented_perf_test_source: str, + function_to_optimize: Any, + test_path: Path, + test_cfg: Any, + project_module_system: str | None, + ) -> tuple[str, str, str]: + _ = function_to_optimize, test_path, test_cfg, project_module_system + return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source def load_coverage(self, *args: Any, **kwargs: Any) -> Any: return None @@ -238,6 +279,36 @@ def add_runtime_comments( def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str: return _remove_tests(test_source, functions_to_remove, self._analyzer) + def add_runtime_comments_to_generated_tests( + self, + generated_tests: GeneratedTestsList, + original_runtimes: dict[InvocationId, list[int]], + optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Path | None = None, + ) -> GeneratedTestsList: + _ = original_runtimes, optimized_runtimes, tests_project_rootdir + return generated_tests + + def remove_test_functions_from_generated_tests( + self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] + ) -> GeneratedTestsList: + from codeflash.models.models import GeneratedTests + + updated_tests: list[GeneratedTests] = [] + for test in generated_tests.generated_tests: + updated_tests.append( + GeneratedTests( + generated_original_test_source=self.remove_test_functions( + test.generated_original_test_source, functions_to_remove + ), + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + ) + return type(generated_tests)(generated_tests=updated_tests) + def get_test_dir_for_source(self, test_dir: Path, source_file: Path | None = None) -> Path | None: if source_file is not None: return source_file.parent diff --git a/codeflash/languages/golang/test_runner.py b/codeflash/languages/golang/test_runner.py index 0d253655d..0a3006d78 100644 --- a/codeflash/languages/golang/test_runner.py +++ b/codeflash/languages/golang/test_runner.py @@ -12,6 +12,7 @@ from codeflash.languages.base import TestResult if TYPE_CHECKING: + from collections.abc import Generator from pathlib import Path logger = logging.getLogger(__name__) @@ -39,7 +40,9 @@ def run_behavioral_tests( cmd = ["go", "test", "-json", "-v", "-count=1", *packages] - proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) + originals = _collect_original_file_paths(test_paths) + with _hide_original_test_files(originals): + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") @@ -82,7 +85,9 @@ def run_benchmarking_tests( *packages, ] - proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) + originals = _collect_original_file_paths(test_paths) + with _hide_original_test_files(originals): + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") @@ -184,6 +189,51 @@ def _collect_test_file_paths(test_paths: Any) -> list[Path]: return [] +def _collect_original_file_paths(test_paths: Any) -> list[Path]: + from pathlib import Path as _Path + + if test_paths is None or not hasattr(test_paths, "test_files"): + return [] + + originals: list[Path] = [] + for tf in test_paths.test_files: + instrumented = getattr(tf, "instrumented_behavior_file_path", None) + original = getattr(tf, "original_file_path", None) + if instrumented is not None and original is not None: + instrumented_p = _Path(instrumented) + original_p = _Path(original) + if instrumented_p != original_p and original_p.exists(): + originals.append(original_p) + return originals + + +@contextlib.contextmanager +def _hide_original_test_files(originals: list[Path]) -> Generator[None, None, None]: + """Temporarily rename original test files so `go test` only sees the instrumented copies. + + Go compiles all *_test.go files in a package together, so having both the original + and its instrumented copy causes duplicate symbol errors. + """ + renamed: list[tuple[Path, Path]] = [] + for original in originals: + hidden = original.with_suffix(".go.codeflash_hidden") + try: + original.rename(hidden) + renamed.append((hidden, original)) + logger.debug("Temporarily hid %s during go test", original) + except OSError: + logger.debug("Could not hide %s, skipping", original) + try: + yield + finally: + for hidden, original in renamed: + try: + hidden.rename(original) + logger.debug("Restored %s", original) + except OSError: + logger.warning("Failed to restore %s from %s", original, hidden) + + def _test_files_to_packages(test_files: list[Path], cwd: Path) -> list[str]: dirs: set[str] = set() for f in test_files: diff --git a/tests/test_languages/test_golang/test_context.py b/tests/test_languages/test_golang/test_context.py index c79092511..705ea2d4b 100644 --- a/tests/test_languages/test_golang/test_context.py +++ b/tests/test_languages/test_golang/test_context.py @@ -68,17 +68,27 @@ def test_no_read_only_context_for_function(self, tmp_path: Path) -> None: ctx = extract_code_context(func, tmp_path.resolve()) assert ctx.read_only_context == "" - def test_helpers_for_function(self, tmp_path: Path) -> None: + def test_helpers_only_includes_called_functions(self, tmp_path: Path) -> None: source_file = (tmp_path / "calc.go").resolve() source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") func = FunctionToOptimize( function_name="Add", file_path=source_file, language="go", starting_line=10, ending_line=12 ) ctx = extract_code_context(func, tmp_path.resolve()) + assert ctx.helper_functions == [] + + def test_helpers_includes_called_function(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "func helper(x int) int { return x * 2 }\n\n" + "func Target(a int) int { return helper(a) }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + source_file.write_text(source, encoding="utf-8") + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + ctx = extract_code_context(func, tmp_path.resolve()) helper_names = [h.name for h in ctx.helper_functions] - assert "subtract" in helper_names - assert "AddFloat" in helper_names - assert "Add" not in helper_names + assert helper_names == ["helper"] def test_language_is_go(self, tmp_path: Path) -> None: source_file = (tmp_path / "calc.go").resolve() @@ -148,27 +158,28 @@ def test_method_helpers_exclude_self(self, tmp_path: Path) -> None: ending_line=21, ) ctx = extract_code_context(func, tmp_path.resolve()) - helper_names = [h.name for h in ctx.helper_functions] - assert "Add" in helper_names - assert "subtract" in helper_names - assert "AddFloat" not in helper_names + assert ctx.helper_functions == [] - def test_method_helper_qualified_names(self, tmp_path: Path) -> None: + def test_method_helpers_with_calls(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "type Calc struct{ Val int }\n\n" + "func double(x int) int { return x * 2 }\n\n" + "func (c *Calc) Compute() int { return double(c.Val) }\n" + ) source_file = (tmp_path / "calc.go").resolve() - source_file.write_text(GO_SOURCE_WITH_METHOD, encoding="utf-8") + source_file.write_text(source, encoding="utf-8") func = FunctionToOptimize( - function_name="AddFloat", + function_name="Compute", file_path=source_file, - parents=[FunctionParent(name="Calculator", type="StructDef")], + parents=[FunctionParent(name="Calc", type="StructDef")], language="go", is_method=True, - starting_line=18, - ending_line=21, ) ctx = extract_code_context(func, tmp_path.resolve()) - helper_qns = [h.qualified_name for h in ctx.helper_functions] - assert "Add" in helper_qns - assert "subtract" in helper_qns + helper_names = [h.name for h in ctx.helper_functions] + assert helper_names == ["double"] + assert "Compute" not in helper_names class TestExtractCodeContextEdgeCases: @@ -323,7 +334,7 @@ def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None: source = ( "package calc\n\n" "type Calc struct{}\n\n" - "func (c Calc) Target() int { return 1 }\n\n" + "func (c Calc) Target() int { return c.Helper() }\n\n" "func (c Calc) Helper() int { return 2 }\n" ) source_file = (tmp_path / "calc.go").resolve() @@ -337,3 +348,27 @@ def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None: helpers = find_helper_functions(source, func) assert len(helpers) == 1 assert helpers[0].qualified_name == "Calc.Helper" + + def test_transitive_helpers(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "func innerHelper(x int) int { return x }\n\n" + "func outerHelper(x int) int { return innerHelper(x) }\n\n" + "func Target(a int) int { return outerHelper(a) }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + helpers = find_helper_functions(source, func) + helper_names = sorted(h.name for h in helpers) + assert helper_names == ["innerHelper", "outerHelper"] + + def test_uncalled_functions_excluded(self, tmp_path: Path) -> None: + source = ( + "package calc\n\n" + "func unrelated() int { return 99 }\n\n" + "func Target(a int) int { return a + 1 }\n" + ) + source_file = (tmp_path / "calc.go").resolve() + func = FunctionToOptimize(function_name="Target", file_path=source_file, language="go") + helpers = find_helper_functions(source, func) + assert helpers == [] diff --git a/tests/test_languages/test_golang/test_function_optimizer.py b/tests/test_languages/test_golang/test_function_optimizer.py index 60915223b..6dff40585 100644 --- a/tests/test_languages/test_golang/test_function_optimizer.py +++ b/tests/test_languages/test_golang/test_function_optimizer.py @@ -171,44 +171,6 @@ def test_full_assembled_code_string(self, tmp_path: Path) -> None: func Add(a, b int) int { \treturn a + b } - - - func subtract(a, b int) int { - \treturn a - b - } - - func multiply(a, b int) int { - \treturn a * b - } - - // Greet builds a greeting message. - func Greet(name string) string { - \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) - } - - - // AddFloat adds a float value and records history. - func (c *Calculator) AddFloat(val float64) float64 { - \tc.Result += val - \tc.History = append(c.History, c.Result) - \treturn c.Result - } - - - // Sqrt computes the square root of the current result. - func (c *Calculator) Sqrt() float64 { - \tc.Result = math.Sqrt(c.Result) - \tc.History = append(c.History, c.Result) - \treturn c.Result - } - - - // Reset zeroes out the calculator. - func (c Calculator) Reset() Calculator { - \tc.Result = 0 - \tc.History = nil - \treturn c - } """) assert code == expected @@ -227,39 +189,9 @@ def test_code_excludes_interface_definition(self, tmp_path: Path) -> None: code = result.read_writable_code.code_strings[0].code assert "type Formatter interface" not in code - def test_helpers_include_other_functions_and_methods(self, tmp_path: Path) -> None: - result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) - helper_names = sorted(h.only_function_name for h in result.helper_functions) - assert "subtract" in helper_names - assert "multiply" in helper_names - assert "Greet" in helper_names - assert "AddFloat" in helper_names - assert "Sqrt" in helper_names - assert "Reset" in helper_names - assert "Add" not in helper_names - - def test_helper_sources_are_full_functions(self, tmp_path: Path) -> None: + def test_no_helpers_when_no_calls(self, tmp_path: Path) -> None: result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) - by_name = {h.only_function_name: h for h in result.helper_functions} - - assert by_name["subtract"].source_code == dedent("""\ - func subtract(a, b int) int { - \treturn a - b - }""") - - assert by_name["Greet"].source_code == dedent("""\ - // Greet builds a greeting message. - func Greet(name string) string { - \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) - } - """) - - def test_method_helpers_have_qualified_names(self, tmp_path: Path) -> None: - result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) - by_name = {h.only_function_name: h for h in result.helper_functions} - assert by_name["AddFloat"].qualified_name == "Calculator.AddFloat" - assert by_name["AddFloat"].fully_qualified_name == "Calculator.AddFloat" - assert by_name["subtract"].qualified_name == "subtract" + assert result.helper_functions == [] def test_no_read_only_context_for_plain_function(self, tmp_path: Path) -> None: result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) @@ -315,42 +247,6 @@ def test_full_assembled_code_string(self, tmp_path: Path) -> None: \tc.History = append(c.History, c.Result) \treturn c.Result } - - - // Add returns the sum of two integers. - func Add(a, b int) int { - \treturn a + b - } - - - func subtract(a, b int) int { - \treturn a - b - } - - func multiply(a, b int) int { - \treturn a * b - } - - // Greet builds a greeting message. - func Greet(name string) string { - \treturn fmt.Sprintf("Hello, %s", str.TrimSpace(name)) - } - - - // Sqrt computes the square root of the current result. - func (c *Calculator) Sqrt() float64 { - \tc.Result = math.Sqrt(c.Result) - \tc.History = append(c.History, c.Result) - \treturn c.Result - } - - - // Reset zeroes out the calculator. - func (c Calculator) Reset() Calculator { - \tc.Result = 0 - \tc.History = nil - \treturn c - } """) assert code == expected @@ -369,16 +265,9 @@ def test_read_only_context_is_struct_definition(self, tmp_path: Path) -> None: \tHistory []float64 }""") - def test_helpers_exclude_self_include_others(self, tmp_path: Path) -> None: + def test_no_helpers_when_no_calls(self, tmp_path: Path) -> None: result = self._build(tmp_path) - helper_names = sorted(h.only_function_name for h in result.helper_functions) - assert "AddFloat" not in helper_names - assert "Add" in helper_names - assert "subtract" in helper_names - assert "multiply" in helper_names - assert "Greet" in helper_names - assert "Sqrt" in helper_names - assert "Reset" in helper_names + assert result.helper_functions == [] def test_target_not_duplicated_in_code_string(self, tmp_path: Path) -> None: result = self._build(tmp_path) @@ -418,21 +307,15 @@ def test_target_in_code_string(self, tmp_path: Path) -> None: assert code.count("func (c Calculator) Reset()") == 1 assert expected_target in code - def test_helpers_include_other_methods_on_same_struct(self, tmp_path: Path) -> None: + def test_no_helpers_when_no_calls(self, tmp_path: Path) -> None: result = self._build(tmp_path) - helper_names = sorted(h.only_function_name for h in result.helper_functions) - assert "Reset" not in helper_names - assert "AddFloat" in helper_names - assert "Sqrt" in helper_names - assert "Add" in helper_names + assert result.helper_functions == [] - def test_helper_code_in_assembled_string(self, tmp_path: Path) -> None: + def test_no_helper_code_in_assembled_string(self, tmp_path: Path) -> None: result = self._build(tmp_path) code = result.read_writable_code.code_strings[0].code - assert "func (c *Calculator) AddFloat" in code - assert "func (c *Calculator) Sqrt()" in code - assert "func Add(a, b int) int" in code - assert "func subtract(a, b int) int" in code + assert "func (c *Calculator) AddFloat" not in code + assert "func Add(a, b int) int" not in code def test_struct_in_read_only_context(self, tmp_path: Path) -> None: result = self._build(tmp_path) diff --git a/tests/test_languages/test_golang/test_parse.py b/tests/test_languages/test_golang/test_parse.py new file mode 100644 index 000000000..92dd5c97e --- /dev/null +++ b/tests/test_languages/test_golang/test_parse.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from unittest.mock import MagicMock + +from codeflash.languages.golang.parse import BENCHMARK_RE, parse_go_test_output +from codeflash.models.models import TestFile, TestFiles +from codeflash.models.test_type import TestType + + +def _make_test_config(tmp_path: Path) -> MagicMock: + cfg = MagicMock() + cfg.tests_project_rootdir = tmp_path + cfg.test_framework = "go-test" + return cfg + + +def _make_test_files(tmp_path: Path, filenames: list[str] | None = None, test_type: TestType = TestType.GENERATED_REGRESSION) -> TestFiles: + if filenames is None: + filenames = ["calc_test.go"] + files = [] + for name in filenames: + path = (tmp_path / name).resolve() + path.write_text("package calc\n", encoding="utf-8") + files.append( + TestFile( + instrumented_behavior_file_path=path, + test_type=test_type, + ) + ) + return TestFiles(test_files=files) + + +def _write_jsonl(path: Path, events: list[dict]) -> None: + path.write_text("\n".join(json.dumps(e) for e in events) + "\n", encoding="utf-8") + + +class TestBenchmarkRegex: + def test_basic_benchmark_line(self) -> None: + line = "BenchmarkAdd-8 \t 1000000\t 1234 ns/op\t 56 B/op\t 2 allocs/op" + m = BENCHMARK_RE.search(line) + assert m is not None + assert m.group(1) == "BenchmarkAdd" + assert m.group(2) == "1000000" + assert m.group(3) == "1234" + assert m.group(4) == "56" + assert m.group(5) == "2" + + def test_benchmark_without_mem(self) -> None: + line = "BenchmarkSort 5000 300000 ns/op" + m = BENCHMARK_RE.search(line) + assert m is not None + assert m.group(1) == "BenchmarkSort" + assert m.group(4) is None + assert m.group(5) is None + + def test_benchmark_with_float_ns(self) -> None: + line = "BenchmarkFib-16 100000 12345.67 ns/op" + m = BENCHMARK_RE.search(line) + assert m is not None + assert m.group(3) == "12345.67" + + def test_non_benchmark_line(self) -> None: + line = "=== RUN TestAdd" + m = BENCHMARK_RE.search(line) + assert m is None + + +class TestParseGoTestOutputBehavioral: + def test_all_passing(self, tmp_path: Path) -> None: + events = [ + {"Time": "2024-01-01T00:00:00Z", "Action": "run", "Package": "example/calc", "Test": "TestAdd"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "output", "Package": "example/calc", "Test": "TestAdd", "Output": "=== RUN TestAdd\n"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "output", "Package": "example/calc", "Test": "TestAdd", "Output": "--- PASS: TestAdd (0.00s)\n"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "pass", "Package": "example/calc", "Test": "TestAdd", "Elapsed": 0.001}, + {"Time": "2024-01-01T00:00:00Z", "Action": "run", "Package": "example/calc", "Test": "TestSub"}, + {"Time": "2024-01-01T00:00:00Z", "Action": "pass", "Package": "example/calc", "Test": "TestSub", "Elapsed": 0.002}, + {"Time": "2024-01-01T00:00:00Z", "Action": "pass", "Package": "example/calc"}, + ] + + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 2 + + by_name = {r.id.test_function_name: r for r in results.test_results} + assert by_name["TestAdd"].did_pass is True + assert by_name["TestAdd"].runtime == 1_000_000 + assert by_name["TestSub"].did_pass is True + assert by_name["TestSub"].runtime == 2_000_000 + + def test_with_failure(self, tmp_path: Path) -> None: + events = [ + {"Action": "run", "Package": "example/calc", "Test": "TestAdd"}, + {"Action": "output", "Package": "example/calc", "Test": "TestAdd", "Output": "got 4, want 5\n"}, + {"Action": "fail", "Package": "example/calc", "Test": "TestAdd", "Elapsed": 0.01}, + ] + + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + assert results.test_results[0].did_pass is False + assert "got 4, want 5" in results.test_results[0].stdout + + def test_mixed_pass_fail(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestA", "Elapsed": 0.001}, + {"Action": "fail", "Package": "p", "Test": "TestB", "Elapsed": 0.002}, + {"Action": "pass", "Package": "p", "Test": "TestC", "Elapsed": 0.003}, + ] + + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + by_name = {r.id.test_function_name: r for r in results.test_results} + assert by_name["TestA"].did_pass is True + assert by_name["TestB"].did_pass is False + assert by_name["TestC"].did_pass is True + + def test_empty_file(self, tmp_path: Path) -> None: + json_path = (tmp_path / "empty.jsonl").resolve() + json_path.write_text("", encoding="utf-8") + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 0 + + def test_missing_file_falls_back_to_run_result(self, tmp_path: Path) -> None: + json_path = (tmp_path / "nonexistent.jsonl").resolve() + events = [ + {"Action": "pass", "Package": "p", "Test": "TestX", "Elapsed": 0.005}, + ] + stdout = "\n".join(json.dumps(e) for e in events) + run_result = subprocess.CompletedProcess(args=[], returncode=0, stdout=stdout, stderr="") + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg, run_result) + assert len(results.test_results) == 1 + assert results.test_results[0].id.test_function_name == "TestX" + + def test_invalid_json_lines_skipped(self, tmp_path: Path) -> None: + content = 'not json\n{"Action":"pass","Package":"p","Test":"TestOK","Elapsed":0.001}\nalso bad\n' + json_path = (tmp_path / "results.jsonl").resolve() + json_path.write_text(content, encoding="utf-8") + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + assert results.test_results[0].did_pass is True + + def test_test_type_from_test_files(self, tmp_path: Path) -> None: + test_files = _make_test_files(tmp_path, test_type=TestType.EXISTING_UNIT_TEST) + events = [ + {"Action": "pass", "Package": "p", "Test": "TestFoo", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert results.test_results[0].test_type == TestType.EXISTING_UNIT_TEST + + def test_framework_is_go_test(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestBar", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert results.test_results[0].test_framework == "go-test" + + def test_package_level_events_ignored(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestOK", "Elapsed": 0.001}, + {"Action": "pass", "Package": "p", "Elapsed": 0.5}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + + +class TestParseGoTestOutputBenchmark: + def test_benchmark_parsing(self, tmp_path: Path) -> None: + events = [ + {"Action": "run", "Package": "p", "Test": "BenchmarkAdd"}, + {"Action": "output", "Package": "p", "Test": "BenchmarkAdd", "Output": "BenchmarkAdd-8 \t 1000000\t 1234 ns/op\t 56 B/op\t 2 allocs/op\n"}, + {"Action": "pass", "Package": "p", "Test": "BenchmarkAdd", "Elapsed": 1.5}, + ] + json_path = (tmp_path / "bench.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert len(results.test_results) == 1 + result = results.test_results[0] + assert result.did_pass is True + assert result.runtime == 1234 + + def test_benchmark_overrides_elapsed(self, tmp_path: Path) -> None: + events = [ + {"Action": "output", "Package": "p", "Test": "BenchmarkSort", "Output": "BenchmarkSort 5000 300000 ns/op\n"}, + {"Action": "pass", "Package": "p", "Test": "BenchmarkSort", "Elapsed": 2.0}, + ] + json_path = (tmp_path / "bench.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + assert results.test_results[0].runtime == 300000 + + def test_mixed_tests_and_benchmarks(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + {"Action": "output", "Package": "p", "Test": "BenchmarkAdd", "Output": "BenchmarkAdd-8 1000000 500 ns/op\n"}, + {"Action": "pass", "Package": "p", "Test": "BenchmarkAdd", "Elapsed": 1.0}, + ] + json_path = (tmp_path / "mixed.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + by_name = {r.id.test_function_name: r for r in results.test_results} + + assert by_name["TestAdd"].runtime == 1_000_000 + assert by_name["BenchmarkAdd"].runtime == 500 + + +class TestParseGoTestOutputInvocationId: + def test_invocation_id_fields(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "example/calc", "Test": "TestAdd", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results = parse_go_test_output(json_path, test_files, cfg) + inv = results.test_results[0] + assert inv.id.test_module_path == "example/calc" + assert inv.id.test_class_name is None + assert inv.id.test_function_name == "TestAdd" + assert inv.loop_index == 1 + + def test_unique_invocation_loop_id_stable(self, tmp_path: Path) -> None: + events = [ + {"Action": "pass", "Package": "p", "Test": "TestA", "Elapsed": 0.001}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + results1 = parse_go_test_output(json_path, test_files, cfg) + results2 = parse_go_test_output(json_path, test_files, cfg) + + id1 = results1.test_results[0].unique_invocation_loop_id + id2 = results2.test_results[0].unique_invocation_loop_id + assert id1 == id2 + + +class TestParseGoTestOutputComparison: + def test_behavioral_comparison_same_results(self, tmp_path: Path) -> None: + from codeflash.verification.equivalence import compare_test_results + + events = [ + {"Action": "pass", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + {"Action": "pass", "Package": "p", "Test": "TestSub", "Elapsed": 0.002}, + ] + json_path = (tmp_path / "results.jsonl").resolve() + _write_jsonl(json_path, events) + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + original = parse_go_test_output(json_path, test_files, cfg) + candidate = parse_go_test_output(json_path, test_files, cfg) + + are_equal, diffs = compare_test_results(original, candidate, pass_fail_only=True) + assert are_equal is True + assert diffs == [] + + def test_behavioral_comparison_different_results(self, tmp_path: Path) -> None: + from codeflash.verification.equivalence import compare_test_results + + original_events = [ + {"Action": "pass", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + ] + candidate_events = [ + {"Action": "fail", "Package": "p", "Test": "TestAdd", "Elapsed": 0.001}, + ] + orig_path = (tmp_path / "orig.jsonl").resolve() + cand_path = (tmp_path / "cand.jsonl").resolve() + _write_jsonl(orig_path, original_events) + _write_jsonl(cand_path, candidate_events) + + test_files = _make_test_files(tmp_path) + cfg = _make_test_config(tmp_path) + + original = parse_go_test_output(orig_path, test_files, cfg) + candidate = parse_go_test_output(cand_path, test_files, cfg) + + are_equal, diffs = compare_test_results(original, candidate, pass_fail_only=True) + assert are_equal is False + assert len(diffs) == 1 + + def test_empty_results_not_equal(self, tmp_path: Path) -> None: + from codeflash.models.models import TestResults + from codeflash.verification.equivalence import compare_test_results + + are_equal, _diffs = compare_test_results(TestResults(), TestResults(), pass_fail_only=True) + assert are_equal is False diff --git a/tests/test_languages/test_golang/test_support.py b/tests/test_languages/test_golang/test_support.py index df67d84eb..5c415c78e 100644 --- a/tests/test_languages/test_golang/test_support.py +++ b/tests/test_languages/test_golang/test_support.py @@ -4,7 +4,7 @@ from codeflash.languages.golang.support import GoSupport from codeflash.languages.language_enum import Language -from codeflash.languages.registry import clear_cache, clear_registry, get_language_support +from codeflash.languages.registry import get_language_support class TestGoSupportProperties: @@ -118,3 +118,113 @@ def test_get_test_dir_for_source(self) -> None: source_file = Path("/project/pkg/calc.go") result = support.get_test_dir_for_source(Path("/project"), source_file) assert result == Path("/project/pkg") + + def test_get_module_path(self) -> None: + support = GoSupport() + source_file = Path("/project/pkg/calc.go") + result = support.get_module_path(source_file, Path("/project")) + assert result == str(source_file) + + def test_setup_test_config_returns_true(self) -> None: + support = GoSupport() + + class FakeTestCfg: + project_root_path = Path("/nonexistent") + + assert support.setup_test_config(FakeTestCfg(), Path("/file.go")) is True + + def test_prepare_module_valid(self, tmp_path: Path) -> None: + from codeflash.models.models import ValidCode + + support = GoSupport() + code = "package main\n\nfunc main() {}\n" + module_path = (tmp_path / "main.go").resolve() + result = support.prepare_module(code, module_path, tmp_path) + assert result is not None + validated, ast_node = result + assert ast_node is None + assert module_path in validated + assert isinstance(validated[module_path], ValidCode) + assert validated[module_path].source_code == code + + def test_prepare_module_invalid(self, tmp_path: Path) -> None: + support = GoSupport() + result = support.prepare_module("func {{{ invalid", (tmp_path / "bad.go").resolve(), tmp_path) + assert result is None + + def test_instrument_existing_test_reads_file(self, tmp_path: Path) -> None: + support = GoSupport() + test_file = (tmp_path / "calc_test.go").resolve() + test_file.write_text("package calc\n\nfunc TestAdd(t *testing.T) {}\n", encoding="utf-8") + success, content = support.instrument_existing_test( + test_path=test_file, call_positions=[], function_to_optimize=None, tests_project_root=tmp_path, mode="behavior" + ) + assert success is True + assert content is not None + assert "TestAdd" in content + + def test_instrument_existing_test_missing_file(self, tmp_path: Path) -> None: + support = GoSupport() + success, content = support.instrument_existing_test( + test_path=(tmp_path / "missing.go").resolve(), + call_positions=[], + function_to_optimize=None, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success is False + assert content is None + + def test_postprocess_generated_tests_passthrough(self) -> None: + support = GoSupport() + sentinel = object() + result = support.postprocess_generated_tests(sentinel, "go-test", Path("/project"), Path("/project/calc.go")) # type: ignore[arg-type] + assert result is sentinel + + def test_process_generated_test_strings_passthrough(self) -> None: + support = GoSupport() + gen, beh, perf = support.process_generated_test_strings( + "gen_code", "beh_code", "perf_code", None, Path("/test.go"), None, None + ) + assert gen == "gen_code" + assert beh == "beh_code" + assert perf == "perf_code" + + def test_add_runtime_comments_to_generated_tests_passthrough(self) -> None: + support = GoSupport() + sentinel = object() + result = support.add_runtime_comments_to_generated_tests(sentinel, {}, {}) # type: ignore[arg-type] + assert result is sentinel + + def test_remove_test_functions_from_generated_tests(self) -> None: + from codeflash.models.models import GeneratedTests, GeneratedTestsList + + support = GoSupport() + source = """\ +package calc + +import "testing" + +func TestAdd(t *testing.T) { +\tif Add(1, 2) != 3 { +\t\tt.Fatal("bad") +\t} +} + +func TestSub(t *testing.T) { +\tif Sub(3, 1) != 2 { +\t\tt.Fatal("bad") +\t} +} +""" + gt = GeneratedTests( + generated_original_test_source=source, + instrumented_behavior_test_source=source, + instrumented_perf_test_source=source, + behavior_file_path=Path("/test_beh.go"), + perf_file_path=Path("/test_perf.go"), + ) + tests_list = GeneratedTestsList(generated_tests=[gt]) + result = support.remove_test_functions_from_generated_tests(tests_list, ["TestSub"]) + assert "TestAdd" in result.generated_tests[0].generated_original_test_source + assert "TestSub" not in result.generated_tests[0].generated_original_test_source diff --git a/tests/test_languages/test_golang/test_test_runner.py b/tests/test_languages/test_golang/test_test_runner.py index c812f8c60..a218e6cb7 100644 --- a/tests/test_languages/test_golang/test_test_runner.py +++ b/tests/test_languages/test_golang/test_test_runner.py @@ -2,7 +2,12 @@ from pathlib import Path -from codeflash.languages.golang.test_runner import parse_go_test_json, parse_test_results +from codeflash.languages.golang.test_runner import ( + _collect_original_file_paths, + _hide_original_test_files, + parse_go_test_json, + parse_test_results, +) GO_TEST_JSON_ALL_PASS = """\ @@ -90,3 +95,86 @@ def test_falls_back_to_stdout(self, tmp_path: Path) -> None: assert len(results) == 1 assert results[0].test_name == "TestBad" assert results[0].passed is False + + +class _FakeTestFile: + def __init__(self, instrumented: Path | None = None, original: Path | None = None) -> None: + self.instrumented_behavior_file_path = instrumented + self.original_file_path = original + + +class _FakeTestFiles: + def __init__(self, test_files: list[_FakeTestFile]) -> None: + self.test_files = test_files + + +class TestCollectOriginalFilePaths: + def test_returns_originals_when_instrumented_differs(self, tmp_path: Path) -> None: + original = (tmp_path / "sorting_test.go").resolve() + original.write_text("package x", encoding="utf-8") + instrumented = (tmp_path / "sorting__perfinstrumented_test.go").resolve() + tf = _FakeTestFile(instrumented=instrumented, original=original) + result = _collect_original_file_paths(_FakeTestFiles([tf])) + assert result == [original] + + def test_skips_when_same_path(self, tmp_path: Path) -> None: + original = (tmp_path / "sorting_test.go").resolve() + original.write_text("package x", encoding="utf-8") + tf = _FakeTestFile(instrumented=original, original=original) + result = _collect_original_file_paths(_FakeTestFiles([tf])) + assert result == [] + + def test_skips_missing_original(self, tmp_path: Path) -> None: + original = (tmp_path / "missing_test.go").resolve() + instrumented = (tmp_path / "missing__perfinstrumented_test.go").resolve() + tf = _FakeTestFile(instrumented=instrumented, original=original) + result = _collect_original_file_paths(_FakeTestFiles([tf])) + assert result == [] + + def test_none_test_paths(self) -> None: + assert _collect_original_file_paths(None) == [] + + +class TestHideOriginalTestFiles: + def test_hides_and_restores(self, tmp_path: Path) -> None: + original = (tmp_path / "sorting_test.go").resolve() + original.write_text("package x\n\nfunc TestSort(t *testing.T) {}", encoding="utf-8") + + with _hide_original_test_files([original]): + assert not original.exists() + assert original.with_suffix(".go.codeflash_hidden").exists() + + assert original.exists() + assert not original.with_suffix(".go.codeflash_hidden").exists() + assert original.read_text(encoding="utf-8") == "package x\n\nfunc TestSort(t *testing.T) {}" + + def test_restores_even_on_exception(self, tmp_path: Path) -> None: + original = (tmp_path / "sorting_test.go").resolve() + original.write_text("content", encoding="utf-8") + + try: + with _hide_original_test_files([original]): + raise RuntimeError("boom") + except RuntimeError: + pass + + assert original.exists() + assert not original.with_suffix(".go.codeflash_hidden").exists() + + def test_empty_list_is_noop(self) -> None: + with _hide_original_test_files([]): + pass + + def test_multiple_files(self, tmp_path: Path) -> None: + files = [] + for name in ("a_test.go", "b_test.go"): + f = (tmp_path / name).resolve() + f.write_text(f"package {name}", encoding="utf-8") + files.append(f) + + with _hide_original_test_files(files): + for f in files: + assert not f.exists() + + for f in files: + assert f.exists() From 1961d7c9e588285a60c3af2db1f345bfc389e871 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 12:42:43 +0300 Subject: [PATCH 05/10] =?UTF-8?q?feat:=20Go=20benchmarking=20pipeline=20wi?= =?UTF-8?q?th=20adaptive=20loops=20and=20Test=E2=86=92Benchmark=20conversi?= =?UTF-8?q?on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds end-to-end Go benchmarking: convert Test* functions to Benchmark* with b.N loop wrapping, run go test with adaptive loop bounded by max_loops and target_duration_seconds, parse benchmark JSON output (ns/op from Output lines), and prefer goimports over gofmt for unused import cleanup after code replacement. --- codeflash/api/aiservice.py | 2 - codeflash/languages/function_optimizer.py | 85 ++-- codeflash/languages/golang/formatter.py | 50 ++- .../languages/golang/function_optimizer.py | 29 +- codeflash/languages/golang/instrumentation.py | 45 ++ codeflash/languages/golang/parse.py | 110 ++--- codeflash/languages/golang/support.py | 29 +- codeflash/languages/golang/test_runner.py | 233 +++++++--- codeflash/setup/detector.py | 12 +- .../test_golang/test_formatter.py | 6 +- .../test_golang/test_function_optimizer.py | 15 +- .../test_golang/test_instrumentation.py | 151 +++++++ .../test_golang/test_test_runner.py | 411 +++++++++++++++--- 13 files changed, 948 insertions(+), 230 deletions(-) create mode 100644 codeflash/languages/golang/instrumentation.py create mode 100644 tests/test_languages/test_golang/test_instrumentation.py diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index ec2960a97..cbf34d17f 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -113,8 +113,6 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) - print(f"url: {url}") - print(f"payload: {json_payload}") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index b18e0c60e..388546a2d 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -21,7 +21,7 @@ from rich.tree import Tree import codeflash.code_utils._libcst_cache # noqa: F401 -from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient +from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import ( @@ -1277,28 +1277,30 @@ def process_single_candidate( self.future_adaptive_optimizations.append(future_adaptive_optimization) else: # Refinement for all languages (Python, JavaScript, TypeScript) - future_refinement = self.executor.submit( - aiservice_client.optimize_code_refinement, - request=[ - AIServiceRefinerRequest( - optimization_id=best_optimization.candidate.optimization_id, - original_source_code=code_context.read_writable_code.markdown, - read_only_dependency_code=code_context.read_only_context_code, - original_code_runtime=original_code_baseline.runtime, - optimized_source_code=best_optimization.candidate.source_code.markdown, - optimized_explanation=best_optimization.candidate.explanation, - optimized_code_runtime=best_optimization.runtime, - speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%", - trace_id=self.get_trace_id(exp_type), - original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], - optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], - function_references=function_references, - language=self.function_to_optimize.language, - language_version=self.language_support.language_version, - ) - ], - rerun_trace_id=self.rerun_trace_id, - ) + # future_refinement = self.executor.submit( + # aiservice_client.optimize_code_refinement, + # request=[ + # AIServiceRefinerRequest( + # optimization_id=best_optimization.candidate.optimization_id, + # original_source_code=code_context.read_writable_code.markdown, + # read_only_dependency_code=code_context.read_only_context_code, + # original_code_runtime=original_code_baseline.runtime, + # optimized_source_code=best_optimization.candidate.source_code.markdown, + # optimized_explanation=best_optimization.candidate.explanation, + # optimized_code_runtime=best_optimization.runtime, + # speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%", + # trace_id=self.get_trace_id(exp_type), + # original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], + # optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], + # function_references=function_references, + # language=self.function_to_optimize.language, + # language_version=self.language_support.language_version, + # ) + # ], + # rerun_trace_id=self.rerun_trace_id, + # ) + future_refinement = concurrent.futures.Future() + future_refinement.set_result([]) self.future_all_refinements.append(future_refinement) # Display runtime information @@ -1345,23 +1347,26 @@ def determine_best_candidate( ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client assert ai_service_client is not None, "AI service client must be set for optimization" - future_line_profile_results = self.executor.submit( - ai_service_client.optimize_python_code_line_profiler, - source_code=code_context.read_writable_code.markdown, - dependency_code=code_context.read_only_context_code, - trace_id=self.get_trace_id(exp_type), - line_profiler_results=original_code_baseline.line_profile_results["str_out"], - n_candidates=get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, self.effort), - experiment_metadata=ExperimentMetadata( - id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" - ) - if self.experiment_id - else None, - is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, - language=self.function_to_optimize.language, - language_version=self.language_support.language_version, - rerun_trace_id=self.rerun_trace_id, - ) + # future_line_profile_results = self.executor.submit( + # ai_service_client.optimize_python_code_line_profiler, + # source_code=code_context.read_writable_code.markdown, + # dependency_code=code_context.read_only_context_code, + # trace_id=self.get_trace_id(exp_type), + # line_profiler_results=original_code_baseline.line_profile_results["str_out"], + # n_candidates=get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, self.effort), + # experiment_metadata=ExperimentMetadata( + # id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" + # ) + # if self.experiment_id + # else None, + # is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, + # language=self.function_to_optimize.language, + # language_version=self.language_support.language_version, + # rerun_trace_id=self.rerun_trace_id, + # ) + + future_line_profile_results = concurrent.futures.Future() + future_line_profile_results.set_result([]) normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip()) processor = CandidateProcessor( diff --git a/codeflash/languages/golang/formatter.py b/codeflash/languages/golang/formatter.py index 26dbe8715..61a0bfe58 100644 --- a/codeflash/languages/golang/formatter.py +++ b/codeflash/languages/golang/formatter.py @@ -12,26 +12,48 @@ def format_go_code(source: str, file_path: Path | None = None) -> str: - gofmt = shutil.which("gofmt") - if gofmt is None: - goimports = shutil.which("goimports") - if goimports is not None: - gofmt = goimports - else: - logger.debug("No Go formatter found (gofmt/goimports), returning source unchanged") - return source + goimports = _find_go_tool("goimports") + if goimports is not None: + formatted = _run_formatter(goimports, source) + if formatted is not None: + return formatted + + gofmt = _find_go_tool("gofmt") + if gofmt is not None: + formatted = _run_formatter(gofmt, source) + if formatted is not None: + return formatted + + logger.debug("No Go formatter found (goimports/gofmt), returning source unchanged") + return source + + +def _find_go_tool(name: str) -> str | None: + import os + from pathlib import Path + found = shutil.which(name) + if found: + return found + gopath = os.environ.get("GOPATH") or str(Path.home() / "go") + for bin_dir in ("bin", str(Path("packages") / "bin")): + candidate = Path(gopath) / bin_dir / name + if candidate.is_file() and os.access(candidate, os.X_OK): + return str(candidate) + return None + + +def _run_formatter(tool: str, source: str) -> str | None: try: - result = subprocess.run([gofmt], input=source, capture_output=True, text=True, timeout=15, check=False) + result = subprocess.run([tool], input=source, capture_output=True, text=True, timeout=15, check=False) if result.returncode == 0: return result.stdout - logger.debug("gofmt failed: %s", result.stderr) + logger.debug("%s failed: %s", tool, result.stderr) except subprocess.TimeoutExpired: - logger.warning("gofmt timed out") + logger.warning("%s timed out", tool) except Exception: - logger.debug("gofmt error", exc_info=True) - - return source + logger.debug("%s error", tool, exc_info=True) + return None def normalize_go_code(source: str) -> str: diff --git a/codeflash/languages/golang/function_optimizer.py b/codeflash/languages/golang/function_optimizer.py index c3113276e..184fc0070 100644 --- a/codeflash/languages/golang/function_optimizer.py +++ b/codeflash/languages/golang/function_optimizer.py @@ -64,10 +64,12 @@ def replace_function_and_helpers_with_optimized_code( original_helper_code: dict[Path, str], ) -> bool: from codeflash.languages.code_replacer import replace_function_definitions_for_language + from codeflash.languages.golang.formatter import format_go_code did_update = False + modified_files: list[Path] = [] for module_abspath, qualified_names in self.group_functions_by_file(code_context).items(): - did_update |= replace_function_definitions_for_language( + updated = replace_function_definitions_for_language( function_names=list(qualified_names), optimized_code=optimized_code, module_abspath=module_abspath, @@ -75,9 +77,29 @@ def replace_function_and_helpers_with_optimized_code( lang_support=self.language_support, function_to_optimize=self.function_to_optimize, ) + if updated: + modified_files.append(module_abspath) + did_update |= updated + + for file_path in modified_files: + source = file_path.read_text(encoding="utf-8") + formatted = format_go_code(source, file_path) + if formatted != source: + file_path.write_text(formatted, encoding="utf-8") + return did_update +def _extract_package_name(file_path: Path) -> str | None: + from codeflash.languages.golang.parser import GoAnalyzer + + try: + source = file_path.read_text(encoding="utf-8") + except OSError: + return None + return GoAnalyzer().find_package_name(source) + + def _build_optimization_context( code_context: CodeContext, file_path: Path, @@ -86,6 +108,8 @@ def _build_optimization_context( optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT, testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT, ) -> CodeOptimizationContext: + package_name = _extract_package_name(file_path) + if code_context.imports: inner = "\n".join(f"\t{imp}" for imp in code_context.imports) imports_code = f"import (\n{inner}\n)" @@ -121,6 +145,9 @@ def _build_optimization_context( if imports_code: target_file_code = imports_code + "\n\n" + target_file_code + if package_name: + target_file_code = f"package {package_name}\n\n" + target_file_code + read_writable_code_strings = [CodeString(code=target_file_code, file_path=target_relative_path, language=language)] for helper_file_path, file_helpers in helpers_by_file.items(): diff --git a/codeflash/languages/golang/instrumentation.py b/codeflash/languages/golang/instrumentation.py new file mode 100644 index 000000000..e9c928c3f --- /dev/null +++ b/codeflash/languages/golang/instrumentation.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import re + +_FUNC_BODY_RE = re.compile(r"^(func\s+)(Test\w+)(\s*\(\s*)(\w+)(\s+\*testing\.T\s*\)\s*\{)", re.MULTILINE) +_PARALLEL_RE = re.compile(r"^\s*\w+\.Parallel\(\)\s*\n?", re.MULTILINE) +_HELPER_RE = re.compile(r"^\s*\w+\.Helper\(\)\s*\n?", re.MULTILINE) + + +def convert_tests_to_benchmarks(test_source: str, target_function_name: str = "") -> str: + if not test_source.strip(): + return test_source + + if not _FUNC_BODY_RE.search(test_source): + return test_source + + result = test_source + + for match in reversed(list(_FUNC_BODY_RE.finditer(result))): + func_prefix = match.group(1) + test_name = match.group(2) + paren_open = match.group(3) + param_name = match.group(4) + + bench_name = "Benchmark" + test_name[len("Test") :] + + body_start = match.end() + brace_depth = 1 + pos = body_start + while pos < len(result) and brace_depth > 0: + if result[pos] == "{": + brace_depth += 1 + elif result[pos] == "}": + brace_depth -= 1 + pos += 1 + + body = result[body_start : pos - 1] + + new_sig = f"{func_prefix}{bench_name}{paren_open}{param_name} *testing.B) {{\n\tfor i := 0; i < {param_name}.N; i++ {{" + new_func = f"{new_sig}{body}\t}}\n}}" + result = result[: match.start()] + new_func + result[pos:] + + result = result.replace("*testing.T", "*testing.B") + result = _PARALLEL_RE.sub("", result) + return _HELPER_RE.sub("", result) diff --git a/codeflash/languages/golang/parse.py b/codeflash/languages/golang/parse.py index 40ed18f91..42da7a59c 100644 --- a/codeflash/languages/golang/parse.py +++ b/codeflash/languages/golang/parse.py @@ -43,8 +43,8 @@ def parse_go_test_output( logger.warning("No valid JSON events found in %s", test_json_path) return test_results - test_states: dict[str, _TestState] = {} - benchmark_results: dict[str, _BenchmarkResult] = {} + iterations: list[_TestIteration] = [] + active: dict[str, _TestIteration] = {} for event in events: action = event.get("Action") @@ -52,70 +52,85 @@ def parse_go_test_output( package = event.get("Package", "") if test_name is None: + if action == "output": + output_text = event.get("Output", "") + bench_match = BENCHMARK_RE.search(output_text) + if bench_match: + bench_name = bench_match.group(1) + it = _TestIteration(test_name=bench_name, package=package) + it.passed = True + it.bench_ns_per_op = float(bench_match.group(3)) + it.bench_iterations = int(bench_match.group(2)) + it.stdout = output_text + iterations.append(it) continue - if test_name not in test_states: - test_states[test_name] = _TestState(package=package) + if action == "run": + if test_name in active: + iterations.append(active[test_name]) + active[test_name] = _TestIteration(test_name=test_name, package=package) + continue - state = test_states[test_name] + it = active.get(test_name) + if it is None: + it = _TestIteration(test_name=test_name, package=package) + active[test_name] = it if action == "output": output_text = event.get("Output", "") - state.stdout += output_text + it.stdout += output_text bench_match = BENCHMARK_RE.search(output_text) if bench_match: - bench_name = bench_match.group(1) - iterations = int(bench_match.group(2)) - ns_per_op = float(bench_match.group(3)) - b_per_op = int(bench_match.group(4)) if bench_match.group(4) else None - allocs_per_op = int(bench_match.group(5)) if bench_match.group(5) else None - benchmark_results[bench_name] = _BenchmarkResult( - ns_per_op=ns_per_op, iterations=iterations, b_per_op=b_per_op, allocs_per_op=allocs_per_op - ) - elif action == "pass": - state.passed = True - elapsed = event.get("Elapsed", 0) - state.runtime_ns = int(elapsed * 1_000_000_000) if elapsed else None - elif action == "fail": - state.passed = False + it.bench_ns_per_op = float(bench_match.group(3)) + it.bench_iterations = int(bench_match.group(2)) + elif action in ("pass", "fail"): + it.passed = action == "pass" elapsed = event.get("Elapsed", 0) - state.runtime_ns = int(elapsed * 1_000_000_000) if elapsed else None + it.elapsed_ns = int(elapsed * 1_000_000_000) if elapsed else None + iterations.append(active.pop(test_name)) + + for it in active.values(): + if it.passed is not None: + iterations.append(it) + loop_counters: dict[str, int] = {} base_dir = test_config.tests_project_rootdir - for test_name, state in test_states.items(): - if state.passed is None: + for it in iterations: + if it.passed is None: continue - test_file_path = _resolve_test_file(test_name, state.package, test_files, base_dir) + loop_index = loop_counters.get(it.test_name, 0) + 1 + loop_counters[it.test_name] = loop_index + + runtime_ns = it.bench_ns_per_op if it.bench_ns_per_op is not None else it.elapsed_ns + if runtime_ns is not None: + runtime_ns = int(runtime_ns) + + test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir) test_type = _resolve_test_type(test_file_path, test_files) if test_type is None: - logger.debug("Skipping test %s: could not resolve test type", test_name) + logger.debug("Skipping test %s: could not resolve test type", it.test_name) continue - runtime_ns = state.runtime_ns - bench = benchmark_results.get(test_name) - if bench is not None: - runtime_ns = int(bench.ns_per_op) - test_results.add( FunctionTestInvocation( - loop_index=1, + loop_index=loop_index, id=InvocationId( - test_module_path=state.package, + test_module_path=it.package, test_class_name=None, - test_function_name=test_name, + test_function_name=it.test_name, function_getting_tested="", iteration_id="", ), file_name=test_file_path, runtime=runtime_ns, test_framework="go-test", - did_pass=state.passed, + did_pass=it.passed, test_type=test_type, return_value=None, timed_out=False, - stdout=state.stdout, + stdout=it.stdout, ) ) @@ -124,31 +139,24 @@ def parse_go_test_output( if run_result is not None: logger.debug("stdout: %s\nstderr: %s", run_result.stdout, run_result.stderr) + logger.debug("[BENCHMARK-DONE] Got %d benchmark results", len(test_results)) + return test_results -class _TestState: - __slots__ = ("package", "passed", "runtime_ns", "stdout") +class _TestIteration: + __slots__ = ("bench_iterations", "bench_ns_per_op", "elapsed_ns", "package", "passed", "stdout", "test_name") - def __init__(self, package: str) -> None: + def __init__(self, test_name: str, package: str) -> None: + self.test_name = test_name self.package = package self.passed: bool | None = None - self.runtime_ns: int | None = None + self.elapsed_ns: int | None = None + self.bench_ns_per_op: float | None = None + self.bench_iterations: int | None = None self.stdout: str = "" -class _BenchmarkResult: - __slots__ = ("allocs_per_op", "b_per_op", "iterations", "ns_per_op") - - def __init__( - self, ns_per_op: float, iterations: int, b_per_op: int | None = None, allocs_per_op: int | None = None - ) -> None: - self.ns_per_op = ns_per_op - self.iterations = iterations - self.b_per_op = b_per_op - self.allocs_per_op = allocs_per_op - - def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str: try: content = path.read_text(encoding="utf-8") diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index e6c0e78dc..0552735c8 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -220,16 +220,25 @@ def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOpt return source def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: - return test_source + from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + + func_name = target_function.function_name if target_function else "" + return convert_tests_to_benchmarks(test_source, func_name) def instrument_existing_test( self, test_path: Path, call_positions: Any, function_to_optimize: Any, tests_project_root: Path, mode: str ) -> tuple[bool, str | None]: - _ = call_positions, function_to_optimize, tests_project_root, mode + _ = call_positions, tests_project_root try: - return True, test_path.read_text(encoding="utf-8") + source = test_path.read_text(encoding="utf-8") except Exception: return False, None + if mode == "performance": + from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + + func_name = function_to_optimize.function_name if function_to_optimize else "" + source = convert_tests_to_benchmarks(source, func_name) + return True, source def postprocess_generated_tests( self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path @@ -247,7 +256,11 @@ def process_generated_test_strings( test_cfg: Any, project_module_system: str | None, ) -> tuple[str, str, str]: - _ = function_to_optimize, test_path, test_cfg, project_module_system + _ = test_path, test_cfg, project_module_system + from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + + func_name = function_to_optimize.function_name if function_to_optimize else "" + instrumented_perf_test_source = convert_tests_to_benchmarks(instrumented_perf_test_source, func_name) return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source def load_coverage(self, *args: Any, **kwargs: Any) -> Any: @@ -256,6 +269,14 @@ def load_coverage(self, *args: Any, **kwargs: Any) -> Any: def get_test_file_suffix(self) -> str: return "_test.go" + def resolve_test_file_from_class_path(self, test_class_path: str, base_dir: Path) -> Path | None: + return None + + def resolve_test_module_path_for_pr( + self, test_module_path: str, tests_project_rootdir: Path, non_generated_tests: set[Path] + ) -> Path | None: + return None + def find_test_root(self, project_root: Path) -> Path | None: return project_root diff --git a/codeflash/languages/golang/test_runner.py b/codeflash/languages/golang/test_runner.py index 0a3006d78..0accd4953 100644 --- a/codeflash/languages/golang/test_runner.py +++ b/codeflash/languages/golang/test_runner.py @@ -4,9 +4,11 @@ import json import logging import os +import re import signal import subprocess import sys +import time from typing import TYPE_CHECKING, Any from codeflash.languages.base import TestResult @@ -38,10 +40,13 @@ def run_behavioral_tests( env = {**os.environ, **test_env} - cmd = ["go", "test", "-json", "-v", "-count=1", *packages] - - originals = _collect_original_file_paths(test_paths) - with _hide_original_test_files(originals): + others = _collect_other_test_files(test_file_paths) + with _hide_other_test_files(others), _deduplicated_test_files(test_file_paths): + run_regex = _build_run_regex(test_file_paths) + cmd = ["go", "test", "-json", "-v", "-count=1"] + if run_regex: + cmd.extend(["-run", run_regex]) + cmd.extend(packages) proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") @@ -64,30 +69,72 @@ def run_benchmarking_tests( result_dir.mkdir(parents=True, exist_ok=True) json_output_file = result_dir / "benchmark.jsonl" - test_file_paths = _collect_test_file_paths(test_paths) + test_file_paths = _collect_test_file_paths(test_paths, use_benchmarking=True) packages = _test_files_to_packages(test_file_paths, cwd) if not packages: packages = ["./..."] env = {**os.environ, **test_env} - benchtime = f"{target_duration_seconds:.0f}s" - cmd = [ - "go", - "test", - "-json", - "-v", - "-bench=.", - f"-benchtime={benchtime}", - "-benchmem", - f"-count={min_loops}", - "-run=^$", - *packages, - ] - - originals = _collect_original_file_paths(test_paths) - with _hide_original_test_files(originals): - proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=timeout) + others = _collect_other_test_files(test_file_paths) + with _hide_other_test_files(others), _deduplicated_test_files(test_file_paths): + bench_regex = _build_bench_regex(test_file_paths) + if bench_regex: + benchtime_secs = min(target_duration_seconds, 1.0) + num_benchmarks = len(_extract_func_names(test_file_paths, _BENCH_FUNC_RE)) + per_loop_estimate = int(num_benchmarks * benchtime_secs * 2) + 10 + cmd = [ + "go", + "test", + "-json", + "-v", + f"-bench={bench_regex}", + f"-benchtime={benchtime_secs:.0f}s", + # "-benchmem", + "-count=1", # setting count to as we looping manually to track timeout and max_loop + "-run=^$", + f"-timeout={per_loop_estimate}s", + *packages, + ] + # logger.info("Benchmark command: %s", cmd) + all_stdout: list[str] = [] + all_stderr: list[str] = [] + last_returncode = 0 + start_time = time.monotonic() + for loop in range(1, max_loops + 1): + proc_result = _run_cmd_kill_pg_on_timeout(cmd, cwd=cwd, env=env, timeout=per_loop_estimate) + if proc_result.stdout: + all_stdout.append(proc_result.stdout) + if proc_result.stderr: + all_stderr.append(proc_result.stderr) + last_returncode = proc_result.returncode + if proc_result.returncode != 0: + logger.warning( + "Benchmark loop %d failed (rc=%d):\nstdout:%s\nstderr: %s", + loop, + proc_result.returncode, + proc_result.stdout, + proc_result.stderr, + ) + break + elapsed = time.monotonic() - start_time + if loop >= min_loops and elapsed >= target_duration_seconds: + logger.info( + "Benchmark stopping after %d loops (%.1fs elapsed, target %.1fs)", + loop, + elapsed, + target_duration_seconds, + ) + break + logger.info("Benchmark completed %d loop(s), returncode: %d", loop, last_returncode) + combined_stdout = "".join(all_stdout) + combined_stderr = "".join(all_stderr) + proc_result = subprocess.CompletedProcess( + args=cmd, returncode=last_returncode, stdout=combined_stdout, stderr=combined_stderr + ) + else: + logger.warning("No Benchmark* functions found in perf test files: %s", [str(p) for p in test_file_paths]) + proc_result = subprocess.CompletedProcess(args=[], returncode=0, stdout="", stderr="") json_output_file.write_text(proc_result.stdout or "", encoding="utf-8") @@ -169,7 +216,7 @@ def _package_to_path(package: str) -> Path: return _Path() -def _collect_test_file_paths(test_paths: Any) -> list[Path]: +def _collect_test_file_paths(test_paths: Any, *, use_benchmarking: bool = False) -> list[Path]: from pathlib import Path as _Path if test_paths is None: @@ -178,7 +225,12 @@ def _collect_test_file_paths(test_paths: Any) -> list[Path]: if hasattr(test_paths, "test_files"): paths = [] for tf in test_paths.test_files: - p = getattr(tf, "instrumented_behavior_file_path", None) or getattr(tf, "original_file_path", None) + if use_benchmarking: + p = getattr(tf, "benchmarking_file_path", None) or getattr(tf, "perf_file_path", None) + else: + p = getattr(tf, "instrumented_behavior_file_path", None) + if p is None: + p = getattr(tf, "original_file_path", None) if p is not None: paths.append(_Path(p)) return paths @@ -189,40 +241,39 @@ def _collect_test_file_paths(test_paths: Any) -> list[Path]: return [] -def _collect_original_file_paths(test_paths: Any) -> list[Path]: - from pathlib import Path as _Path +def _collect_other_test_files(test_file_paths: list[Path]) -> list[Path]: - if test_paths is None or not hasattr(test_paths, "test_files"): + if not test_file_paths: return [] - originals: list[Path] = [] - for tf in test_paths.test_files: - instrumented = getattr(tf, "instrumented_behavior_file_path", None) - original = getattr(tf, "original_file_path", None) - if instrumented is not None and original is not None: - instrumented_p = _Path(instrumented) - original_p = _Path(original) - if instrumented_p != original_p and original_p.exists(): - originals.append(original_p) - return originals + keep = {f.resolve() for f in test_file_paths} + dirs = {f.resolve().parent for f in test_file_paths} + + others: list[Path] = [] + for d in dirs: + for f in d.glob("*_test.go"): + if f.resolve() not in keep and f.exists(): + others.append(f) + return others @contextlib.contextmanager -def _hide_original_test_files(originals: list[Path]) -> Generator[None, None, None]: - """Temporarily rename original test files so `go test` only sees the instrumented copies. +def _hide_other_test_files(others: list[Path]) -> Generator[None, None, None]: + """Temporarily rename test files we don't want compiled. - Go compiles all *_test.go files in a package together, so having both the original - and its instrumented copy causes duplicate symbol errors. + Go compiles ALL *_test.go files in a package together, so any duplicate + symbols across test files cause build errors. We hide every test file in + the target directories except the ones we intend to run. """ renamed: list[tuple[Path, Path]] = [] - for original in originals: - hidden = original.with_suffix(".go.codeflash_hidden") + for f in others: + hidden = f.with_suffix(".go.codeflash_hidden") try: - original.rename(hidden) - renamed.append((hidden, original)) - logger.debug("Temporarily hid %s during go test", original) + f.rename(hidden) + renamed.append((hidden, f)) + logger.debug("Temporarily hid %s during go test", f) except OSError: - logger.debug("Could not hide %s, skipping", original) + logger.debug("Could not hide %s, skipping", f) try: yield finally: @@ -234,12 +285,96 @@ def _hide_original_test_files(originals: list[Path]) -> Generator[None, None, No logger.warning("Failed to restore %s from %s", original, hidden) +_TEST_FUNC_RE = re.compile(r"^func\s+(Test\w+)\s*\(", re.MULTILINE) +_BENCH_FUNC_RE = re.compile(r"^func\s+(Benchmark\w+)\s*\(", re.MULTILINE) +_FUNC_DECL_RE = re.compile(r"^(func\s+)(Test\w+|Benchmark\w+)(\s*\()", re.MULTILINE) + + +def _extract_func_names(test_files: list[Path], pattern: re.Pattern[str]) -> list[str]: + names: list[str] = [] + for f in test_files: + try: + content = f.read_text(encoding="utf-8") + except OSError: + continue + names.extend(pattern.findall(content)) + return names + + +def _build_run_regex(test_files: list[Path]) -> str | None: + names = _extract_func_names(test_files, _TEST_FUNC_RE) + if not names: + return None + return f"^({'|'.join(re.escape(n) for n in names)})$" + + +def _build_bench_regex(test_files: list[Path]) -> str | None: + names = _extract_func_names(test_files, _BENCH_FUNC_RE) + if not names: + return None + return f"^({'|'.join(re.escape(n) for n in names)})$" + + +def _deduplicate_test_func_names(test_files: list[Path]) -> dict[Path, str]: + seen: dict[str, int] = {} + originals: dict[Path, str] = {} + + for f in test_files: + try: + content = f.read_text(encoding="utf-8") + except OSError: + continue + + names_in_file = [name for _, name, _ in _FUNC_DECL_RE.findall(content)] + if not names_in_file: + continue + + needs_rewrite = any(name in seen for name in names_in_file) + + if not needs_rewrite: + for name in names_in_file: + seen[name] = 1 + continue + + originals[f] = content + + def _renamer(m: re.Match[str]) -> str: + prefix, name, suffix = m.group(1), m.group(2), m.group(3) + if name not in seen: + seen[name] = 1 + return m.group(0) + idx = seen[name] + seen[name] = idx + 1 + return f"{prefix}{name}_{idx}{suffix}" + + new_content = _FUNC_DECL_RE.sub(_renamer, content) + f.write_text(new_content, encoding="utf-8") + logger.debug("Deduplicated test function names in %s", f) + + return originals + + +@contextlib.contextmanager +def _deduplicated_test_files(test_files: list[Path]) -> Generator[None, None, None]: + originals = _deduplicate_test_func_names(test_files) + try: + yield + finally: + for f, content in originals.items(): + try: + f.write_text(content, encoding="utf-8") + except OSError: + logger.warning("Failed to restore original content for %s", f) + + def _test_files_to_packages(test_files: list[Path], cwd: Path) -> list[str]: dirs: set[str] = set() + resolved_cwd = cwd.resolve() for f in test_files: try: - rel = f.resolve().parent.relative_to(cwd.resolve()) - dirs.add(f"./{rel.as_posix()}") + rel = f.resolve().parent.relative_to(resolved_cwd) + pkg = f"./{rel.as_posix()}" if rel.parts else "." + dirs.add(pkg) except ValueError: continue return sorted(dirs) if dirs else [] diff --git a/codeflash/setup/detector.py b/codeflash/setup/detector.py index 3d7a0ea45..a84f26b72 100644 --- a/codeflash/setup/detector.py +++ b/codeflash/setup/detector.py @@ -848,10 +848,14 @@ def _detect_go_formatter(project_root: Path) -> tuple[list[str], str]: Go has a universal formatter (gofmt). goimports is preferred if available because it also manages imports. """ - if shutil.which("goimports"): - return ["goimports -w $file"], "goimports (auto-detected)" - if shutil.which("gofmt"): - return ["gofmt -w $file"], "gofmt (auto-detected)" + from codeflash.languages.golang.formatter import _find_go_tool + + goimports = _find_go_tool("goimports") + if goimports: + return [f"{goimports} -w $file"], "goimports (auto-detected)" + gofmt = _find_go_tool("gofmt") + if gofmt: + return [f"{gofmt} -w $file"], "gofmt (auto-detected)" return ["gofmt -w $file"], "gofmt (default)" diff --git a/tests/test_languages/test_golang/test_formatter.py b/tests/test_languages/test_golang/test_formatter.py index 2c36f44b6..4665aaa6d 100644 --- a/tests/test_languages/test_golang/test_formatter.py +++ b/tests/test_languages/test_golang/test_formatter.py @@ -55,7 +55,9 @@ def test_mixed_comments(self) -> None: "}\n" ) result = normalize_go_code(source) - expected = "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}\nfunc Subtract(a, b int) int {\nreturn a - b\n}" + expected = ( + "package calc\nfunc Add(a, b int) int {\nreturn a + b\n}\nfunc Subtract(a, b int) int {\nreturn a - b\n}" + ) assert result == expected def test_inline_block_comment(self) -> None: @@ -76,7 +78,7 @@ def test_only_comments(self) -> None: class TestFormatGoCode: def test_no_formatter_returns_source(self) -> None: source = "package calc\n\nfunc Add(a, b int) int {\nreturn a+b\n}\n" - with patch("codeflash.languages.golang.formatter.shutil.which", return_value=None): + with patch("codeflash.languages.golang.formatter._find_go_tool", return_value=None): result = format_go_code(source) assert result == source diff --git a/tests/test_languages/test_golang/test_function_optimizer.py b/tests/test_languages/test_golang/test_function_optimizer.py index 6dff40585..b76c82ba6 100644 --- a/tests/test_languages/test_golang/test_function_optimizer.py +++ b/tests/test_languages/test_golang/test_function_optimizer.py @@ -161,6 +161,8 @@ def test_full_assembled_code_string(self, tmp_path: Path) -> None: code = result.read_writable_code.code_strings[0].code expected = dedent("""\ + package calc + import ( \t"fmt" \t"math" @@ -174,10 +176,10 @@ def test_full_assembled_code_string(self, tmp_path: Path) -> None: """) assert code == expected - def test_code_excludes_package_clause(self, tmp_path: Path) -> None: + def test_code_includes_package_clause(self, tmp_path: Path) -> None: result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) code = result.read_writable_code.code_strings[0].code - assert "package calc" not in code + assert code.startswith("package calc\n") def test_code_excludes_struct_definition(self, tmp_path: Path) -> None: result = _build_context_for_function(CALCULATOR_SOURCE, "calc.go", "Add", tmp_path) @@ -235,6 +237,8 @@ def test_full_assembled_code_string(self, tmp_path: Path) -> None: code = result.read_writable_code.code_strings[0].code expected = dedent("""\ + package calc + import ( \t"fmt" \t"math" @@ -250,10 +254,9 @@ def test_full_assembled_code_string(self, tmp_path: Path) -> None: """) assert code == expected - def test_code_excludes_package_and_type_defs(self, tmp_path: Path) -> None: + def test_code_excludes_type_defs(self, tmp_path: Path) -> None: result = self._build(tmp_path) code = result.read_writable_code.code_strings[0].code - assert "package calc" not in code assert "type Calculator struct" not in code assert "type Formatter interface" not in code @@ -334,10 +337,12 @@ def test_struct_in_read_only_context(self, tmp_path: Path) -> None: class TestBuildContextMinimalSource: """Target: Double(x int) — minimal file with no imports or structs.""" - def test_no_imports_no_prefix(self, tmp_path: Path) -> None: + def test_no_imports_package_only_prefix(self, tmp_path: Path) -> None: result = _build_context_for_function(SIMPLE_SOURCE, "simple.go", "Double", tmp_path) code = result.read_writable_code.code_strings[0].code assert code == dedent("""\ + package simple + func Double(x int) int { \treturn x * 2 }""") diff --git a/tests/test_languages/test_golang/test_instrumentation.py b/tests/test_languages/test_golang/test_instrumentation.py new file mode 100644 index 000000000..07f608d91 --- /dev/null +++ b/tests/test_languages/test_golang/test_instrumentation.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks + +SIMPLE_TEST = """\ +package sample + +import "testing" + +func TestAdd(t *testing.T) { +\tgot := Add(1, 2) +\tif got != 3 { +\t\tt.Errorf("Add(1, 2) = %d, want 3", got) +\t} +} +""" + +TEST_WITH_SUBTESTS = """\ +package sample + +import "testing" + +func TestBubbleSort_BasicCases(t *testing.T) { +\ttests := []struct { +\t\tname string +\t\tinput []int +\t\twant []int +\t}{ +\t\t{"sorted", []int{1, 2, 3}, []int{1, 2, 3}}, +\t} +\tfor _, tt := range tests { +\t\tt.Run(tt.name, func(t *testing.T) { +\t\t\tgot := BubbleSort(tt.input) +\t\t\tif len(got) != len(tt.want) { +\t\t\t\tt.Errorf("wrong length") +\t\t\t} +\t\t}) +\t} +} +""" + +MULTIPLE_TESTS = """\ +package sample + +import "testing" + +func TestFoo(t *testing.T) { +\tFoo() +} + +func TestBar(t *testing.T) { +\tBar() +} +""" + +BENCHMARK_ONLY = """\ +package sample + +import "testing" + +func BenchmarkFoo(b *testing.B) { +\tfor i := 0; i < b.N; i++ { +\t\tFoo() +\t} +} +""" + +TEST_WITH_HELPER = """\ +package sample + +import "testing" + +func equalSlices(t *testing.T, got, want []int) { +\tif len(got) != len(want) { +\t\tt.Fatalf("length mismatch") +\t} +} + +func TestBFS(t *testing.T) { +\tgot := BFS(graph, 0) +\tequalSlices(t, got, []int{0, 1, 2}) +} +""" + +TEST_WITH_PARALLEL = """\ +package sample + +import "testing" + +func TestFoo(t *testing.T) { +\tt.Parallel() +\tFoo() +} + +func TestBar(t *testing.T) { +\tt.Helper() +\tt.Parallel() +\tBar() +} +""" + + +class TestConvertTestsToBenchmarks: + def test_simple_test(self) -> None: + result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") + assert "func BenchmarkAdd(" in result + assert "*testing.B)" in result + assert "for i := 0; i < " in result + assert ".N; i++ {" in result + assert "func TestAdd(" not in result + + def test_subtests_converted(self) -> None: + result = convert_tests_to_benchmarks(TEST_WITH_SUBTESTS, "BubbleSort") + assert "func BenchmarkBubbleSort_BasicCases(" in result + assert "*testing.T" not in result + + def test_multiple_functions(self) -> None: + result = convert_tests_to_benchmarks(MULTIPLE_TESTS, "Foo") + assert "func BenchmarkFoo(" in result + assert "func BenchmarkBar(" in result + assert "func TestFoo(" not in result + assert "func TestBar(" not in result + + def test_empty_source(self) -> None: + assert convert_tests_to_benchmarks("", "Foo") == "" + + def test_no_test_functions(self) -> None: + result = convert_tests_to_benchmarks(BENCHMARK_ONLY, "Foo") + assert result == BENCHMARK_ONLY + + def test_package_preserved(self) -> None: + result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") + assert result.startswith("package sample") + + def test_import_preserved(self) -> None: + result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") + assert 'import "testing"' in result + + def test_helper_functions_converted(self) -> None: + result = convert_tests_to_benchmarks(TEST_WITH_HELPER, "BFS") + assert "func BenchmarkBFS(" in result + assert "*testing.T" not in result + assert "equalSlices" in result + assert "*testing.B" in result + + def test_parallel_removed(self) -> None: + result = convert_tests_to_benchmarks(TEST_WITH_PARALLEL, "Foo") + assert ".Parallel()" not in result + assert ".Helper()" not in result + assert "func BenchmarkFoo(" in result + assert "func BenchmarkBar(" in result diff --git a/tests/test_languages/test_golang/test_test_runner.py b/tests/test_languages/test_golang/test_test_runner.py index a218e6cb7..dea932f3c 100644 --- a/tests/test_languages/test_golang/test_test_runner.py +++ b/tests/test_languages/test_golang/test_test_runner.py @@ -3,8 +3,16 @@ from pathlib import Path from codeflash.languages.golang.test_runner import ( - _collect_original_file_paths, - _hide_original_test_files, + _build_bench_regex, + _build_run_regex, + _collect_other_test_files, + _deduplicate_test_func_names, + _deduplicated_test_files, + _extract_func_names, + _hide_other_test_files, + _test_files_to_packages, + _BENCH_FUNC_RE, + _TEST_FUNC_RE, parse_go_test_json, parse_test_results, ) @@ -79,10 +87,7 @@ def test_zero_elapsed(self) -> None: class TestParseTestResults: def test_reads_from_file(self, tmp_path: Path) -> None: json_file = (tmp_path / "results.jsonl").resolve() - json_file.write_text( - '{"Action":"pass","Package":"calc","Test":"TestAdd","Elapsed":0.001}\n', - encoding="utf-8", - ) + json_file.write_text('{"Action":"pass","Package":"calc","Test":"TestAdd","Elapsed":0.001}\n', encoding="utf-8") results = parse_test_results(json_file, "") assert len(results) == 1 assert results[0].test_name == "TestAdd" @@ -97,72 +102,90 @@ def test_falls_back_to_stdout(self, tmp_path: Path) -> None: assert results[0].passed is False -class _FakeTestFile: - def __init__(self, instrumented: Path | None = None, original: Path | None = None) -> None: - self.instrumented_behavior_file_path = instrumented - self.original_file_path = original - - -class _FakeTestFiles: - def __init__(self, test_files: list[_FakeTestFile]) -> None: - self.test_files = test_files - - -class TestCollectOriginalFilePaths: - def test_returns_originals_when_instrumented_differs(self, tmp_path: Path) -> None: - original = (tmp_path / "sorting_test.go").resolve() - original.write_text("package x", encoding="utf-8") - instrumented = (tmp_path / "sorting__perfinstrumented_test.go").resolve() - tf = _FakeTestFile(instrumented=instrumented, original=original) - result = _collect_original_file_paths(_FakeTestFiles([tf])) - assert result == [original] - - def test_skips_when_same_path(self, tmp_path: Path) -> None: - original = (tmp_path / "sorting_test.go").resolve() - original.write_text("package x", encoding="utf-8") - tf = _FakeTestFile(instrumented=original, original=original) - result = _collect_original_file_paths(_FakeTestFiles([tf])) +class TestCollectOtherTestFiles: + def test_finds_other_test_files_in_same_dir(self, tmp_path: Path) -> None: + keep = (tmp_path / "instrumented_test.go").resolve() + keep.write_text("package x", encoding="utf-8") + other1 = (tmp_path / "sorting_test.go").resolve() + other1.write_text("package x", encoding="utf-8") + other2 = (tmp_path / "perf_test.go").resolve() + other2.write_text("package x", encoding="utf-8") + + result = _collect_other_test_files([keep]) + resolved = {f.resolve() for f in result} + assert other1.resolve() in resolved + assert other2.resolve() in resolved + assert keep.resolve() not in resolved + + def test_keeps_only_specified_files(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text("package x", encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text("package x", encoding="utf-8") + + result = _collect_other_test_files([f1, f2]) assert result == [] - def test_skips_missing_original(self, tmp_path: Path) -> None: - original = (tmp_path / "missing_test.go").resolve() - instrumented = (tmp_path / "missing__perfinstrumented_test.go").resolve() - tf = _FakeTestFile(instrumented=instrumented, original=original) - result = _collect_original_file_paths(_FakeTestFiles([tf])) - assert result == [] - - def test_none_test_paths(self) -> None: - assert _collect_original_file_paths(None) == [] - - -class TestHideOriginalTestFiles: + def test_ignores_non_test_files(self, tmp_path: Path) -> None: + keep = (tmp_path / "target_test.go").resolve() + keep.write_text("package x", encoding="utf-8") + non_test = (tmp_path / "helper.go").resolve() + non_test.write_text("package x", encoding="utf-8") + + result = _collect_other_test_files([keep]) + assert all(f.name.endswith("_test.go") for f in result) + assert non_test not in result + + def test_empty_list(self) -> None: + assert _collect_other_test_files([]) == [] + + def test_multiple_dirs(self, tmp_path: Path) -> None: + d1 = (tmp_path / "pkg1").resolve() + d1.mkdir() + d2 = (tmp_path / "pkg2").resolve() + d2.mkdir() + keep1 = (d1 / "target_test.go").resolve() + keep1.write_text("package pkg1", encoding="utf-8") + other1 = (d1 / "old_test.go").resolve() + other1.write_text("package pkg1", encoding="utf-8") + keep2 = (d2 / "target_test.go").resolve() + keep2.write_text("package pkg2", encoding="utf-8") + + result = _collect_other_test_files([keep1, keep2]) + resolved = {f.resolve() for f in result} + assert other1.resolve() in resolved + assert keep1.resolve() not in resolved + assert keep2.resolve() not in resolved + + +class TestHideOtherTestFiles: def test_hides_and_restores(self, tmp_path: Path) -> None: - original = (tmp_path / "sorting_test.go").resolve() - original.write_text("package x\n\nfunc TestSort(t *testing.T) {}", encoding="utf-8") + other = (tmp_path / "sorting_test.go").resolve() + other.write_text("package x\n\nfunc TestSort(t *testing.T) {}", encoding="utf-8") - with _hide_original_test_files([original]): - assert not original.exists() - assert original.with_suffix(".go.codeflash_hidden").exists() + with _hide_other_test_files([other]): + assert not other.exists() + assert other.with_suffix(".go.codeflash_hidden").exists() - assert original.exists() - assert not original.with_suffix(".go.codeflash_hidden").exists() - assert original.read_text(encoding="utf-8") == "package x\n\nfunc TestSort(t *testing.T) {}" + assert other.exists() + assert not other.with_suffix(".go.codeflash_hidden").exists() + assert other.read_text(encoding="utf-8") == "package x\n\nfunc TestSort(t *testing.T) {}" def test_restores_even_on_exception(self, tmp_path: Path) -> None: - original = (tmp_path / "sorting_test.go").resolve() - original.write_text("content", encoding="utf-8") + other = (tmp_path / "sorting_test.go").resolve() + other.write_text("content", encoding="utf-8") try: - with _hide_original_test_files([original]): + with _hide_other_test_files([other]): raise RuntimeError("boom") except RuntimeError: pass - assert original.exists() - assert not original.with_suffix(".go.codeflash_hidden").exists() + assert other.exists() + assert not other.with_suffix(".go.codeflash_hidden").exists() def test_empty_list_is_noop(self) -> None: - with _hide_original_test_files([]): + with _hide_other_test_files([]): pass def test_multiple_files(self, tmp_path: Path) -> None: @@ -172,9 +195,281 @@ def test_multiple_files(self, tmp_path: Path) -> None: f.write_text(f"package {name}", encoding="utf-8") files.append(f) - with _hide_original_test_files(files): + with _hide_other_test_files(files): for f in files: assert not f.exists() for f in files: assert f.exists() + + +GO_TEST_SOURCE = """\ +package sorting + +import "testing" + +func TestBubbleSort_Basic(t *testing.T) {} +func TestBubbleSort_EdgeCases(t *testing.T) {} +""" + +GO_BENCH_SOURCE = """\ +package sorting + +import "testing" + +func BenchmarkBubbleSort(b *testing.B) {} +func BenchmarkBubbleSort_Large(b *testing.B) {} +""" + +GO_MIXED_SOURCE = """\ +package sorting + +import "testing" + +func TestBubbleSort(t *testing.T) {} +func BenchmarkBubbleSort(b *testing.B) {} +""" + + +class TestExtractFuncNames: + def test_extracts_test_funcs(self, tmp_path: Path) -> None: + f = (tmp_path / "sorting_test.go").resolve() + f.write_text(GO_TEST_SOURCE, encoding="utf-8") + names = _extract_func_names([f], _TEST_FUNC_RE) + assert names == ["TestBubbleSort_Basic", "TestBubbleSort_EdgeCases"] + + def test_extracts_bench_funcs(self, tmp_path: Path) -> None: + f = (tmp_path / "sorting_test.go").resolve() + f.write_text(GO_BENCH_SOURCE, encoding="utf-8") + names = _extract_func_names([f], _BENCH_FUNC_RE) + assert names == ["BenchmarkBubbleSort", "BenchmarkBubbleSort_Large"] + + def test_test_regex_does_not_match_benchmarks(self, tmp_path: Path) -> None: + f = (tmp_path / "sorting_test.go").resolve() + f.write_text(GO_BENCH_SOURCE, encoding="utf-8") + names = _extract_func_names([f], _TEST_FUNC_RE) + assert names == [] + + def test_multiple_files(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text("package x\nfunc TestA(t *testing.T) {}", encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text("package x\nfunc TestB(t *testing.T) {}", encoding="utf-8") + names = _extract_func_names([f1, f2], _TEST_FUNC_RE) + assert names == ["TestA", "TestB"] + + def test_missing_file_skipped(self, tmp_path: Path) -> None: + missing = (tmp_path / "missing_test.go").resolve() + names = _extract_func_names([missing], _TEST_FUNC_RE) + assert names == [] + + def test_empty_list(self) -> None: + assert _extract_func_names([], _TEST_FUNC_RE) == [] + + +class TestBuildRunRegex: + def test_single_test_func(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text("package x\nfunc TestFoo(t *testing.T) {}", encoding="utf-8") + regex = _build_run_regex([f]) + assert regex == "^(TestFoo)$" + + def test_multiple_test_funcs(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_TEST_SOURCE, encoding="utf-8") + regex = _build_run_regex([f]) + assert regex == "^(TestBubbleSort_Basic|TestBubbleSort_EdgeCases)$" + + def test_no_test_funcs_returns_none(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text("package x\nfunc helper() {}", encoding="utf-8") + assert _build_run_regex([f]) is None + + def test_empty_files_returns_none(self) -> None: + assert _build_run_regex([]) is None + + +class TestBuildBenchRegex: + def test_single_bench_func(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text('package x\nimport "testing"\nfunc BenchmarkFoo(b *testing.B) {}', encoding="utf-8") + regex = _build_bench_regex([f]) + assert regex == "^(BenchmarkFoo)$" + + def test_no_bench_funcs_returns_none(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_TEST_SOURCE, encoding="utf-8") + assert _build_bench_regex([f]) is None + + +class TestTestFilesToPackages: + def test_subdirectory(self, tmp_path: Path) -> None: + subdir = (tmp_path / "sorting").resolve() + subdir.mkdir() + f = subdir / "sorting_test.go" + f.write_text("package sorting", encoding="utf-8") + packages = _test_files_to_packages([f.resolve()], tmp_path.resolve()) + assert packages == ["./sorting"] + + def test_root_directory(self, tmp_path: Path) -> None: + f = (tmp_path / "main_test.go").resolve() + f.write_text("package main", encoding="utf-8") + packages = _test_files_to_packages([f], tmp_path.resolve()) + assert packages == ["."] + + def test_deduplicates_same_package(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text("package x", encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text("package x", encoding="utf-8") + packages = _test_files_to_packages([f1, f2], tmp_path.resolve()) + assert packages == ["."] + + def test_multiple_packages(self, tmp_path: Path) -> None: + for name in ("pkg1", "pkg2"): + d = (tmp_path / name).resolve() + d.mkdir() + (d / "x_test.go").write_text(f"package {name}", encoding="utf-8") + f1 = (tmp_path / "pkg1" / "x_test.go").resolve() + f2 = (tmp_path / "pkg2" / "x_test.go").resolve() + packages = _test_files_to_packages([f1, f2], tmp_path.resolve()) + assert packages == ["./pkg1", "./pkg2"] + + def test_empty_list(self, tmp_path: Path) -> None: + assert _test_files_to_packages([], tmp_path.resolve()) == [] + + def test_file_outside_cwd_skipped(self, tmp_path: Path) -> None: + other = (tmp_path / "other").resolve() + other.mkdir() + f = (other / "x_test.go").resolve() + f.write_text("package x", encoding="utf-8") + cwd = (tmp_path / "project").resolve() + cwd.mkdir() + assert _test_files_to_packages([f], cwd) == [] + + +GO_FILE_A = """\ +package x + +import "testing" + +func TestFoo(t *testing.T) {} +func TestBar(t *testing.T) {} +""" + +GO_FILE_B_DUPLICATES = """\ +package x + +import "testing" + +func TestFoo(t *testing.T) {} +func TestBaz(t *testing.T) {} +""" + +GO_FILE_C_MORE_DUPLICATES = """\ +package x + +import "testing" + +func TestFoo(t *testing.T) {} +func TestBar(t *testing.T) {} +func TestNew(t *testing.T) {} +""" + + +class TestDeduplicateTestFuncNames: + def test_no_duplicates_no_changes(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + originals = _deduplicate_test_func_names([f1]) + assert originals == {} + assert f1.read_text(encoding="utf-8") == GO_FILE_A + + def test_renames_duplicates_in_second_file(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + + originals = _deduplicate_test_func_names([f1, f2]) + + assert f1.read_text(encoding="utf-8") == GO_FILE_A + assert f2 in originals + assert originals[f2] == GO_FILE_B_DUPLICATES + + rewritten = f2.read_text(encoding="utf-8") + assert "func TestFoo_1(" in rewritten + assert "func TestBaz(" in rewritten + assert "func TestFoo(" not in rewritten + + def test_renames_across_three_files(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + f3 = (tmp_path / "c_test.go").resolve() + f3.write_text(GO_FILE_C_MORE_DUPLICATES, encoding="utf-8") + + _deduplicate_test_func_names([f1, f2, f3]) + + rewritten_b = f2.read_text(encoding="utf-8") + rewritten_c = f3.read_text(encoding="utf-8") + + assert "func TestFoo_1(" in rewritten_b + assert "func TestFoo_2(" in rewritten_c + assert "func TestBar_1(" in rewritten_c + + def test_empty_list(self) -> None: + assert _deduplicate_test_func_names([]) == {} + + def test_single_file_no_changes(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + originals = _deduplicate_test_func_names([f]) + assert originals == {} + + def test_benchmarks_also_deduplicated(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text('package x\n\nimport "testing"\n\nfunc BenchmarkFoo(b *testing.B) {}\n', encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text('package x\n\nimport "testing"\n\nfunc BenchmarkFoo(b *testing.B) {}\n', encoding="utf-8") + + _deduplicate_test_func_names([f1, f2]) + + assert "func BenchmarkFoo(" in f1.read_text(encoding="utf-8") + rewritten = f2.read_text(encoding="utf-8") + assert "func BenchmarkFoo_1(" in rewritten + + +class TestDeduplicatedTestFiles: + def test_restores_after_context(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + + with _deduplicated_test_files([f1, f2]): + assert "func TestFoo_1(" in f2.read_text(encoding="utf-8") + + assert f2.read_text(encoding="utf-8") == GO_FILE_B_DUPLICATES + + def test_restores_on_exception(self, tmp_path: Path) -> None: + f1 = (tmp_path / "a_test.go").resolve() + f1.write_text(GO_FILE_A, encoding="utf-8") + f2 = (tmp_path / "b_test.go").resolve() + f2.write_text(GO_FILE_B_DUPLICATES, encoding="utf-8") + + try: + with _deduplicated_test_files([f1, f2]): + raise RuntimeError("boom") + except RuntimeError: + pass + + assert f2.read_text(encoding="utf-8") == GO_FILE_B_DUPLICATES + + def test_no_duplicates_is_noop(self, tmp_path: Path) -> None: + f = (tmp_path / "a_test.go").resolve() + f.write_text(GO_FILE_A, encoding="utf-8") + + with _deduplicated_test_files([f]): + assert f.read_text(encoding="utf-8") == GO_FILE_A From 4c4485115a21e3ff820dfa659d098c17be55cddd Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 13:52:53 +0300 Subject: [PATCH 06/10] only benchmark relevant tests to get rid of noise --- code_to_optimize/go/go.mod | 2 +- code_to_optimize/go/sorter/sorting.go | 103 ++++++++++++++++++ codeflash/languages/golang/instrumentation.py | 13 ++- .../test_golang/test_instrumentation.py | 33 +++++- 4 files changed, 145 insertions(+), 6 deletions(-) create mode 100644 code_to_optimize/go/sorter/sorting.go diff --git a/code_to_optimize/go/go.mod b/code_to_optimize/go/go.mod index f037eded9..d45a82bbd 100644 --- a/code_to_optimize/go/go.mod +++ b/code_to_optimize/go/go.mod @@ -1,3 +1,3 @@ module example/codeflash-go-sample -go 1.21 +go 1.26 diff --git a/code_to_optimize/go/sorter/sorting.go b/code_to_optimize/go/sorter/sorting.go new file mode 100644 index 000000000..22e821561 --- /dev/null +++ b/code_to_optimize/go/sorter/sorting.go @@ -0,0 +1,103 @@ +package sorter + +func BubbleSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + // Standard optimized bubble sort: + // - reduce inner loop bound each pass (last elements are already sorted) + // - stop early if no swaps occurred in a pass + for i := 0; i < n-1; i++ { + swapped := false + // after i passes, the last i elements are in place + limit := n - 1 - i + for j := 0; j < limit; j++ { + if result[j] > result[j+1] { + // swap + result[j], result[j+1] = result[j+1], result[j] + swapped = true + } + } + if !swapped { + break + } + } + return result +} + +func BubbleSortDescending(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if result[j] < result[j+1] { + temp := result[j] + result[j] = result[j+1] + result[j+1] = temp + } + } + } + return result +} + +func InsertionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 1; i < n; i++ { + key := result[i] + j := i - 1 + for j >= 0 && result[j] > key { + result[j+1] = result[j] + j-- + } + result[j+1] = key + } + return result +} + +func SelectionSort(arr []int) []int { + if len(arr) == 0 { + return arr + } + + result := make([]int, len(arr)) + copy(result, arr) + n := len(result) + + for i := 0; i < n-1; i++ { + minIdx := i + for j := i + 1; j < n; j++ { + if result[j] < result[minIdx] { + minIdx = j + } + } + result[minIdx], result[i] = result[i], result[minIdx] + } + return result +} + +func IsSorted(arr []int) bool { + for i := 0; i < len(arr)-1; i++ { + if arr[i] > arr[i+1] { + return false + } + } + return true +} diff --git a/codeflash/languages/golang/instrumentation.py b/codeflash/languages/golang/instrumentation.py index e9c928c3f..122d9e6c8 100644 --- a/codeflash/languages/golang/instrumentation.py +++ b/codeflash/languages/golang/instrumentation.py @@ -7,6 +7,12 @@ _HELPER_RE = re.compile(r"^\s*\w+\.Helper\(\)\s*\n?", re.MULTILINE) +def _test_matches_target(test_name: str, target_function_name: str) -> bool: + remainder = test_name[len("Test") :] + segments = remainder.split("_") + return target_function_name in segments + + def convert_tests_to_benchmarks(test_source: str, target_function_name: str = "") -> str: if not test_source.strip(): return test_source @@ -22,8 +28,6 @@ def convert_tests_to_benchmarks(test_source: str, target_function_name: str = "" paren_open = match.group(3) param_name = match.group(4) - bench_name = "Benchmark" + test_name[len("Test") :] - body_start = match.end() brace_depth = 1 pos = body_start @@ -34,7 +38,12 @@ def convert_tests_to_benchmarks(test_source: str, target_function_name: str = "" brace_depth -= 1 pos += 1 + if target_function_name and not _test_matches_target(test_name, target_function_name): + result = result[: match.start()] + result[pos:] + continue + body = result[body_start : pos - 1] + bench_name = "Benchmark" + test_name[len("Test") :] new_sig = f"{func_prefix}{bench_name}{paren_open}{param_name} *testing.B) {{\n\tfor i := 0; i < {param_name}.N; i++ {{" new_func = f"{new_sig}{body}\t}}\n}}" diff --git a/tests/test_languages/test_golang/test_instrumentation.py b/tests/test_languages/test_golang/test_instrumentation.py index 07f608d91..182733ab8 100644 --- a/tests/test_languages/test_golang/test_instrumentation.py +++ b/tests/test_languages/test_golang/test_instrumentation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from codeflash.languages.golang.instrumentation import convert_tests_to_benchmarks +from codeflash.languages.golang.instrumentation import _test_matches_target, convert_tests_to_benchmarks SIMPLE_TEST = """\ package sample @@ -100,6 +100,26 @@ """ +class TestMatchesTarget: + def test_exact_match(self) -> None: + assert _test_matches_target("TestBFS", "BFS") is True + + def test_prefix_segment_match(self) -> None: + assert _test_matches_target("TestBFS_BasicCases", "BFS") is True + + def test_suffix_segment_match(self) -> None: + assert _test_matches_target("TestGraph_BFS", "BFS") is True + + def test_no_match_substring(self) -> None: + assert _test_matches_target("TestBFSHelper", "BFS") is False + + def test_no_match_different_function(self) -> None: + assert _test_matches_target("TestDFS", "BFS") is False + + def test_multi_underscore(self) -> None: + assert _test_matches_target("TestBFS_Large_Graph", "BFS") is True + + class TestConvertTestsToBenchmarks: def test_simple_test(self) -> None: result = convert_tests_to_benchmarks(SIMPLE_TEST, "Add") @@ -114,9 +134,16 @@ def test_subtests_converted(self) -> None: assert "func BenchmarkBubbleSort_BasicCases(" in result assert "*testing.T" not in result - def test_multiple_functions(self) -> None: + def test_multiple_functions_filtered(self) -> None: result = convert_tests_to_benchmarks(MULTIPLE_TESTS, "Foo") assert "func BenchmarkFoo(" in result + assert "func BenchmarkBar(" not in result + assert "func TestFoo(" not in result + assert "func TestBar(" not in result + + def test_multiple_functions_no_filter(self) -> None: + result = convert_tests_to_benchmarks(MULTIPLE_TESTS, "") + assert "func BenchmarkFoo(" in result assert "func BenchmarkBar(" in result assert "func TestFoo(" not in result assert "func TestBar(" not in result @@ -148,4 +175,4 @@ def test_parallel_removed(self) -> None: assert ".Parallel()" not in result assert ".Helper()" not in result assert "func BenchmarkFoo(" in result - assert "func BenchmarkBar(" in result + assert "func BenchmarkBar(" not in result From c4a59a5355ca5f6641c9056f10d4c269352c2524 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 13:53:11 +0300 Subject: [PATCH 07/10] only benchmark relevant tests to get rid of noise --- codeflash/api/aiservice.py | 2 +- codeflash/cli_cmds/cli.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index cbf34d17f..3127649f2 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -49,7 +49,7 @@ def __init__(self) -> None: self.is_local = self.base_url == "http://localhost:8000" # (connect_timeout, read_timeout) — connect should be fast; read # can be slow because the server runs LLM inference. - self.timeout: float | tuple[float, float] | None = (10, 600) + self.timeout: float | tuple[float, float] | None = (10, 300) def get_next_sequence(self) -> int: """Get the next LLM call sequence number.""" diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index b39d9567c..cd4d94787 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -124,6 +124,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: if args.tests_root is None: if is_go_project: + # this is just a placeholder, in go we put generated test files in the same package as the source args.tests_root = args.module_root elif is_java_project: # Try standard Maven/Gradle test directories From 9654991e1288c9a8cc8f1242d62101d8cf165346 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 14:13:12 +0300 Subject: [PATCH 08/10] cleaning up --- codeflash/discovery/functions_to_optimize.py | 2 - codeflash/languages/function_optimizer.py | 84 ++++++++++---------- 2 files changed, 40 insertions(+), 46 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index e311e2190..fdac43c25 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -883,8 +883,6 @@ def is_test_file(file_path_normalized: str) -> bool: site_packages_removed_count += len(_functions) continue if not file_path_normalized.startswith(module_root_str + os.sep): - print(f"module_root_str: {module_root_str}") - print(f"file_path_normalized: {file_path_normalized}") non_modules_removed_count += len(_functions) continue diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 388546a2d..859e6ba16 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -78,6 +78,7 @@ AdaptiveOptimizedCandidate, AIServiceAdaptiveOptimizeRequest, AIServiceCodeRepairRequest, + AIServiceRefinerRequest, BestOptimization, CandidateEvaluationContext, GeneratedTests, @@ -1277,30 +1278,28 @@ def process_single_candidate( self.future_adaptive_optimizations.append(future_adaptive_optimization) else: # Refinement for all languages (Python, JavaScript, TypeScript) - # future_refinement = self.executor.submit( - # aiservice_client.optimize_code_refinement, - # request=[ - # AIServiceRefinerRequest( - # optimization_id=best_optimization.candidate.optimization_id, - # original_source_code=code_context.read_writable_code.markdown, - # read_only_dependency_code=code_context.read_only_context_code, - # original_code_runtime=original_code_baseline.runtime, - # optimized_source_code=best_optimization.candidate.source_code.markdown, - # optimized_explanation=best_optimization.candidate.explanation, - # optimized_code_runtime=best_optimization.runtime, - # speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%", - # trace_id=self.get_trace_id(exp_type), - # original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], - # optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], - # function_references=function_references, - # language=self.function_to_optimize.language, - # language_version=self.language_support.language_version, - # ) - # ], - # rerun_trace_id=self.rerun_trace_id, - # ) - future_refinement = concurrent.futures.Future() - future_refinement.set_result([]) + future_refinement = self.executor.submit( + aiservice_client.optimize_code_refinement, + request=[ + AIServiceRefinerRequest( + optimization_id=best_optimization.candidate.optimization_id, + original_source_code=code_context.read_writable_code.markdown, + read_only_dependency_code=code_context.read_only_context_code, + original_code_runtime=original_code_baseline.runtime, + optimized_source_code=best_optimization.candidate.source_code.markdown, + optimized_explanation=best_optimization.candidate.explanation, + optimized_code_runtime=best_optimization.runtime, + speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%", + trace_id=self.get_trace_id(exp_type), + original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], + optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], + function_references=function_references, + language=self.function_to_optimize.language, + language_version=self.language_support.language_version, + ) + ], + rerun_trace_id=self.rerun_trace_id, + ) self.future_all_refinements.append(future_refinement) # Display runtime information @@ -1347,26 +1346,23 @@ def determine_best_candidate( ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client assert ai_service_client is not None, "AI service client must be set for optimization" - # future_line_profile_results = self.executor.submit( - # ai_service_client.optimize_python_code_line_profiler, - # source_code=code_context.read_writable_code.markdown, - # dependency_code=code_context.read_only_context_code, - # trace_id=self.get_trace_id(exp_type), - # line_profiler_results=original_code_baseline.line_profile_results["str_out"], - # n_candidates=get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, self.effort), - # experiment_metadata=ExperimentMetadata( - # id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" - # ) - # if self.experiment_id - # else None, - # is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, - # language=self.function_to_optimize.language, - # language_version=self.language_support.language_version, - # rerun_trace_id=self.rerun_trace_id, - # ) - - future_line_profile_results = concurrent.futures.Future() - future_line_profile_results.set_result([]) + future_line_profile_results = self.executor.submit( + ai_service_client.optimize_python_code_line_profiler, + source_code=code_context.read_writable_code.markdown, + dependency_code=code_context.read_only_context_code, + trace_id=self.get_trace_id(exp_type), + line_profiler_results=original_code_baseline.line_profile_results["str_out"], + n_candidates=get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, self.effort), + experiment_metadata=ExperimentMetadata( + id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" + ) + if self.experiment_id + else None, + is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts, + language=self.function_to_optimize.language, + language_version=self.language_support.language_version, + rerun_trace_id=self.rerun_trace_id, + ) normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip()) processor = CandidateProcessor( From ea51f780a3e07d0c1e291bc453b8ce4bc9841736 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 14:34:15 +0300 Subject: [PATCH 09/10] version --- codeflash/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/version.py b/codeflash/version.py index b354a5b56..226fdf7ad 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.5.post243.dev0+67cf12392" +__version__ = "0.20.5" From a5aa75d717ab6ffacdd2bce693e36233fb563063 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 28 Apr 2026 14:46:49 +0300 Subject: [PATCH 10/10] fix typing issues --- codeflash/languages/golang/context.py | 3 +- .../languages/golang/function_optimizer.py | 2 +- codeflash/languages/golang/parse.py | 30 ++++++++----------- codeflash/languages/golang/replacement.py | 19 +++++++----- codeflash/languages/golang/support.py | 6 ++-- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/codeflash/languages/golang/context.py b/codeflash/languages/golang/context.py index a2a608e2b..eec372e5d 100644 --- a/codeflash/languages/golang/context.py +++ b/codeflash/languages/golang/context.py @@ -3,8 +3,9 @@ import logging from typing import TYPE_CHECKING -from codeflash.languages.base import CodeContext, HelperFunction, Language +from codeflash.languages.base import CodeContext, HelperFunction from codeflash.languages.golang.parser import GoAnalyzer +from codeflash.languages.language_enum import Language if TYPE_CHECKING: from pathlib import Path diff --git a/codeflash/languages/golang/function_optimizer.py b/codeflash/languages/golang/function_optimizer.py index 184fc0070..0b679da3b 100644 --- a/codeflash/languages/golang/function_optimizer.py +++ b/codeflash/languages/golang/function_optimizer.py @@ -27,7 +27,7 @@ class GoFunctionOptimizer(FunctionOptimizer): def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]: from codeflash.languages import get_language_support - from codeflash.languages.base import Language + from codeflash.languages.language_enum import Language language = Language(self.function_to_optimize.language) lang_support = get_language_support(language) diff --git a/codeflash/languages/golang/parse.py b/codeflash/languages/golang/parse.py index 42da7a59c..e6c0f09e7 100644 --- a/codeflash/languages/golang/parse.py +++ b/codeflash/languages/golang/parse.py @@ -3,7 +3,7 @@ import json import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults @@ -12,6 +12,7 @@ from pathlib import Path from codeflash.models.models import TestFiles + from codeflash.models.test_type import TestType from codeflash.verification.verification_utils import TestConfig logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def parse_go_test_output( test_json_path: Path, test_files: TestFiles, test_config: TestConfig, - run_result: subprocess.CompletedProcess | None = None, + run_result: subprocess.CompletedProcess[str] | None = None, ) -> TestResults: test_results = TestResults() @@ -71,10 +72,11 @@ def parse_go_test_output( active[test_name] = _TestIteration(test_name=test_name, package=package) continue - it = active.get(test_name) - if it is None: - it = _TestIteration(test_name=test_name, package=package) - active[test_name] = it + maybe_it = active.get(test_name) + if maybe_it is None: + maybe_it = _TestIteration(test_name=test_name, package=package) + active[test_name] = maybe_it + it = maybe_it if action == "output": output_text = event.get("Output", "") @@ -109,9 +111,6 @@ def parse_go_test_output( test_file_path = _resolve_test_file(it.test_name, it.package, test_files, base_dir) test_type = _resolve_test_type(test_file_path, test_files) - if test_type is None: - logger.debug("Skipping test %s: could not resolve test type", it.test_name) - continue test_results.add( FunctionTestInvocation( @@ -157,7 +156,7 @@ def __init__(self, test_name: str, package: str) -> None: self.stdout: str = "" -def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None) -> str: +def _read_json_output(path: Path, run_result: subprocess.CompletedProcess[str] | None) -> str: try: content = path.read_text(encoding="utf-8") if content.strip(): @@ -165,15 +164,12 @@ def _read_json_output(path: Path, run_result: subprocess.CompletedProcess | None except Exception: pass if run_result is not None: - stdout = run_result.stdout - if isinstance(stdout, bytes): - stdout = stdout.decode("utf-8", errors="replace") - return stdout or "" + return run_result.stdout or "" return "" -def _parse_json_lines(content: str) -> list[dict]: - events: list[dict] = [] +def _parse_json_lines(content: str) -> list[dict[str, Any]]: + events: list[dict[str, Any]] = [] for line in content.splitlines(): line = line.strip() if not line: @@ -199,7 +195,7 @@ def _resolve_test_file(test_name: str, package: str, test_files: TestFiles, base return base_dir / f"{test_name}.go" -def _resolve_test_type(test_file_path: Path, test_files: TestFiles): +def _resolve_test_type(test_file_path: Path, test_files: TestFiles) -> TestType: from codeflash.models.test_type import TestType test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) diff --git a/codeflash/languages/golang/replacement.py b/codeflash/languages/golang/replacement.py index 68b1c811e..c6301312f 100644 --- a/codeflash/languages/golang/replacement.py +++ b/codeflash/languages/golang/replacement.py @@ -6,7 +6,10 @@ from codeflash.languages.golang.parser import GoAnalyzer if TYPE_CHECKING: + import tree_sitter + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.golang.parser import GoGlobalDeclaration logger = logging.getLogger(__name__) @@ -102,7 +105,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: return original_source orig_decls = analyzer.find_global_declarations(original_source) - orig_names_to_decl: dict[str, object] = {} + orig_names_to_decl: dict[str, GoGlobalDeclaration] = {} for decl in orig_decls: for name in decl.names: orig_names_to_decl[name] = decl @@ -131,7 +134,7 @@ def _merge_global_var_const(optimized_code: str, original_source: str, analyzer: return original_source -def _replace_declaration_block(source: str, orig_decl: object, new_source_code: str) -> str: +def _replace_declaration_block(source: str, orig_decl: GoGlobalDeclaration, new_source_code: str) -> str: lines = source.splitlines(keepends=True) start = orig_decl.starting_line - 1 end = orig_decl.ending_line @@ -186,19 +189,19 @@ def remove_test_functions(test_source: str, functions_to_remove: list[str], anal return "".join(lines) -def _find_doc_comment_start(node: object) -> int | None: - prev = getattr(node, "prev_named_sibling", None) +def _find_doc_comment_start(node: tree_sitter.Node) -> int | None: + prev = node.prev_named_sibling if prev is None: return None - if getattr(prev, "type", None) != "comment": + if prev.type != "comment": return None if prev.end_point.row + 1 != node.start_point.row: return None - comment_start = prev.start_point.row + 1 + comment_start: int = prev.start_point.row + 1 current = prev while True: - earlier = getattr(current, "prev_named_sibling", None) - if earlier is None or getattr(earlier, "type", None) != "comment": + earlier = current.prev_named_sibling + if earlier is None or earlier.type != "comment": break if earlier.end_point.row + 1 != current.start_point.row: break diff --git a/codeflash/languages/golang/support.py b/codeflash/languages/golang/support.py index 0552735c8..aaa7c8e3f 100644 --- a/codeflash/languages/golang/support.py +++ b/codeflash/languages/golang/support.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from codeflash.languages.base import LanguageSupport from codeflash.languages.golang.comparator import compare_test_results as _compare_results from codeflash.languages.golang.config import detect_go_project, detect_go_version from codeflash.languages.golang.context import extract_code_context as _extract_context @@ -29,18 +30,17 @@ DependencyResolver, FunctionFilterCriteria, HelperFunction, - InvocationId, ReferenceInfo, TestInfo, ) from codeflash.models.function_types import FunctionToOptimize - from codeflash.models.models import GeneratedTestsList + from codeflash.models.models import GeneratedTestsList, InvocationId logger = logging.getLogger(__name__) @register_language -class GoSupport: +class GoSupport(LanguageSupport): def __init__(self) -> None: self._analyzer = GoAnalyzer() self._go_version: str | None = None