diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py index 2fc356464..d9f27c85b 100644 --- a/scripts/microgenerator/generate.py +++ b/scripts/microgenerator/generate.py @@ -84,6 +84,13 @@ def _get_type_str(self, node: ast.AST | None) -> str | None: # Handles forward references as strings, e.g., '"Dataset"' if isinstance(node, ast.Constant): return repr(node.value) + + # Handles | union types, e.g., int | float + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + left_str = self._get_type_str(node.left) + right_str = self._get_type_str(node.right) + return f"{left_str} | {right_str}" + return None # Fallback for unhandled types def _collect_types_from_node(self, node: ast.AST | None) -> None: diff --git a/scripts/microgenerator/tests/unit/test_generate_analyzer.py b/scripts/microgenerator/tests/unit/test_generate_analyzer.py index e17078a54..71ee37bf4 100644 --- a/scripts/microgenerator/tests/unit/test_generate_analyzer.py +++ b/scripts/microgenerator/tests/unit/test_generate_analyzer.py @@ -16,7 +16,8 @@ import ast import pytest -from scripts.microgenerator.generate import CodeAnalyzer +import textwrap as tw +from scripts.microgenerator.generate import parse_code, CodeAnalyzer # --- Tests CodeAnalyzer handling of Imports --- @@ -93,7 +94,6 @@ def test_import_extraction(self, code_snippet, expected_imports): class TestCodeAnalyzerAttributes: - @pytest.mark.parametrize( "code_snippet, expected_structure", [ @@ -259,3 +259,250 @@ def test_attribute_extraction(self, code_snippet: str, expected_structure: list) item["attributes"].sort(key=lambda x: x["name"]) assert extracted == expected_structure + + +# --- Mock Types --- +class MyClass: + pass + + +class AnotherClass: + pass + + +class YetAnotherClass: + pass + + +def test_codeanalyzer_finds_class(): + code = tw.dedent( + """ + class MyClass: + pass + """ + ) + analyzer = CodeAnalyzer() + tree = ast.parse(code) + analyzer.visit(tree) + assert len(analyzer.structure) == 1 + assert analyzer.structure[0]["class_name"] == "MyClass" + + +def test_codeanalyzer_finds_multiple_classes(): + code = tw.dedent( + """ + class ClassA: + pass + + + class ClassB: + pass + """ + ) + analyzer = CodeAnalyzer() + tree = ast.parse(code) + analyzer.visit(tree) + assert len(analyzer.structure) == 2 + class_names = sorted([c["class_name"] for c in analyzer.structure]) + assert class_names == ["ClassA", "ClassB"] + + +def test_codeanalyzer_finds_method(): + code = tw.dedent( + """ + class MyClass: + def my_method(self): + pass + """ + ) + analyzer = CodeAnalyzer() + tree = ast.parse(code) + analyzer.visit(tree) + assert len(analyzer.structure) == 1 + assert len(analyzer.structure[0]["methods"]) == 1 + assert analyzer.structure[0]["methods"][0]["method_name"] == "my_method" + + +def test_codeanalyzer_finds_multiple_methods(): + code = tw.dedent( + """ + class MyClass: + def method_a(self): + pass + + def method_b(self): + pass + """ + ) + analyzer = CodeAnalyzer() + tree = ast.parse(code) + analyzer.visit(tree) + assert len(analyzer.structure) == 1 + method_names = sorted([m["method_name"] for m in analyzer.structure[0]["methods"]]) + assert method_names == ["method_a", "method_b"] + + +def test_codeanalyzer_no_classes(): + code = tw.dedent( + """ + def top_level_function(): + pass + """ + ) + analyzer = CodeAnalyzer() + tree = ast.parse(code) + analyzer.visit(tree) + assert len(analyzer.structure) == 0 + + +def test_codeanalyzer_class_with_no_methods(): + code = tw.dedent( + """ + class MyClass: + attribute = 123 + """ + ) + analyzer = CodeAnalyzer() + tree = ast.parse(code) + analyzer.visit(tree) + assert len(analyzer.structure) == 1 + assert analyzer.structure[0]["class_name"] == "MyClass" + assert len(analyzer.structure[0]["methods"]) == 0 + + +# --- Test Data for Parameterization --- +TYPE_TEST_CASES = [ + pytest.param( + tw.dedent( + """ + class TestClass: + def func(self, a: int, b: str) -> bool: return True + """ + ), + [("a", "int"), ("b", "str")], + "bool", + id="simple_types", + ), + pytest.param( + tw.dedent( + """ + from typing import Optional + class TestClass: + def func(self, a: Optional[int]) -> str | None: return 'hello' + """ + ), + [("a", "Optional[int]")], + "str | None", + id="optional_union_none", + ), + pytest.param( + tw.dedent( + """ + from typing import Union + class TestClass: + def func(self, a: int | float, b: Union[str, bytes]) -> None: pass + """ + ), + [("a", "int | float"), ("b", "Union[str, bytes]")], + "None", + id="union_types", + ), + pytest.param( + tw.dedent( + """ + from typing import List, Dict, Tuple + class TestClass: + def func(self, a: List[int], b: Dict[str, float]) -> Tuple[int, str]: return (1, 'a') + """ + ), + [("a", "List[int]"), ("b", "Dict[str, float]")], + "Tuple[int, str]", + id="generic_types", + ), + pytest.param( + tw.dedent( + """ + import datetime + from scripts.microgenerator.tests.unit.test_generate_analyzer import MyClass + class TestClass: + def func(self, a: datetime.date, b: MyClass) -> MyClass: return b + """ + ), + [("a", "datetime.date"), ("b", "MyClass")], + "MyClass", + id="imported_types", + ), + pytest.param( + tw.dedent( + """ + from scripts.microgenerator.tests.unit.test_generate_analyzer import AnotherClass, YetAnotherClass + class TestClass: + def func(self, a: 'AnotherClass') -> 'YetAnotherClass': return AnotherClass() + """ + ), + [("a", "'AnotherClass'")], + "'YetAnotherClass'", + id="forward_refs", + ), + pytest.param( + tw.dedent( + """ + class TestClass: + def func(self, a, b): return a + b + """ + ), + [("a", None), ("b", None)], # No annotations means type is None + None, + id="no_annotations", + ), + pytest.param( + tw.dedent( + """ + from typing import List, Optional, Dict, Union, Any + class TestClass: + def func(self, a: List[Optional[Dict[str, Union[int, str]]]]) -> Dict[str, Any]: return {} + """ + ), + [("a", "List[Optional[Dict[str, Union[int, str]]]]")], + "Dict[str, Any]", + id="complex_nested", + ), + pytest.param( + tw.dedent( + """ + from typing import Literal + class TestClass: + def func(self, a: Literal['one', 'two']) -> Literal[True]: return True + """ + ), + [("a", "Literal['one', 'two']")], + "Literal[True]", + id="literal_type", + ), +] + + +class TestCodeAnalyzerArgsReturns: + @pytest.mark.parametrize( + "code_snippet, expected_args, expected_return", TYPE_TEST_CASES + ) + def test_type_extraction(self, code_snippet, expected_args, expected_return): + structure, imports, types = parse_code(code_snippet) + + assert len(structure) == 1, "Should parse one class" + class_info = structure[0] + assert class_info["class_name"] == "TestClass" + + assert len(class_info["methods"]) == 1, "Should find one method" + method_info = class_info["methods"][0] + assert method_info["method_name"] == "func" + + # Extract args, skipping 'self' + extracted_args = [] + for arg in method_info.get("args", []): + if arg["name"] == "self": + continue + extracted_args.append((arg["name"], arg["type"])) + + assert extracted_args == expected_args + assert method_info.get("return_type") == expected_return