diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py index e3b346dd4..61f47d6c1 100644 --- a/scripts/microgenerator/generate.py +++ b/scripts/microgenerator/generate.py @@ -47,7 +47,7 @@ class CodeAnalyzer(ast.NodeVisitor): """ def __init__(self): - self.structure: List[Dict[str, Any]] = [] + self.analyzed_classes: List[Dict[str, Any]] = [] self.imports: set[str] = set() self.types: set[str] = set() self._current_class_info: Dict[str, Any] | None = None @@ -106,13 +106,19 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None: if type_str: self.types.add(type_str) elif isinstance(node, ast.Subscript): - self._collect_types_from_node(node.value) + # Add the base type of the subscript (e.g., "List", "Dict") + if isinstance(node.value, ast.Name): + self.types.add(node.value.id) + self._collect_types_from_node(node.value) # Recurse on value just in case self._collect_types_from_node(node.slice) elif isinstance(node, (ast.Tuple, ast.List)): for elt in node.elts: self._collect_types_from_node(elt) - elif isinstance(node, ast.Constant) and isinstance(node.value, str): - self.types.add(node.value) + elif isinstance(node, ast.Constant): + if isinstance(node.value, str): # Forward references + self.types.add(node.value) + elif node.value is None: # None type + self.types.add("None") elif isinstance(node, ast.BinOp) and isinstance( node.op, ast.BitOr ): # For | union type @@ -164,7 +170,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: type_str = self._get_type_str(item.annotation) class_info["attributes"].append({"name": attr_name, "type": type_str}) - self.structure.append(class_info) + self.analyzed_classes.append(class_info) self._current_class_info = class_info self._depth += 1 self.generic_visit(node) @@ -260,6 +266,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: # directly within the class body, not inside a method. 
elif isinstance(target, ast.Name) and not self._is_in_method: self._add_attribute(target.id, self._get_type_str(node.annotation)) + self._collect_types_from_node(node.annotation) self.generic_visit(node) @@ -280,7 +287,7 @@ def parse_code(code: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]: tree = ast.parse(code) analyzer = CodeAnalyzer() analyzer.visit(tree) - return analyzer.structure, analyzer.imports, analyzer.types + return analyzer.analyzed_classes, analyzer.imports, analyzer.types def parse_file(file_path: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]: @@ -332,10 +339,10 @@ def list_code_objects( all_class_keys = [] def process_structure( - structure: List[Dict[str, Any]], file_name: str | None = None + analyzed_classes: List[Dict[str, Any]], file_name: str | None = None ): """Populates the results dictionary from the parsed AST structure.""" - for class_info in structure: + for class_info in analyzed_classes: key = class_info["class_name"] if file_name: key = f"{key} (in {file_name})" @@ -361,13 +368,13 @@ def process_structure( # Determine if the path is a file or directory and process accordingly if os.path.isfile(path) and path.endswith(".py"): - structure, _, _ = parse_file(path) - process_structure(structure) + analyzed_classes, _, _ = parse_file(path) + process_structure(analyzed_classes) elif os.path.isdir(path): # This assumes `utils.walk_codebase` is defined elsewhere. 
for file_path in utils.walk_codebase(path): - structure, _, _ = parse_file(file_path) - process_structure(structure, file_name=os.path.basename(file_path)) + analyzed_classes, _, _ = parse_file(file_path) + process_structure(analyzed_classes, file_name=os.path.basename(file_path)) # Return the data in the desired format based on the flags if not show_methods and not show_attributes: @@ -419,11 +426,11 @@ def _build_request_arg_schema( module_name = os.path.splitext(relative_path)[0].replace(os.path.sep, ".") try: - structure, _, _ = parse_file(file_path) - if not structure: + analyzed_classes, _, _ = parse_file(file_path) + if not analyzed_classes: continue - for class_info in structure: + for class_info in analyzed_classes: class_name = class_info.get("class_name", "Unknown") if class_name.endswith("Request"): full_class_name = f"{module_name}.{class_name}" @@ -451,11 +458,11 @@ def _process_service_clients( if "/services/" not in file_path: continue - structure, imports, types = parse_file(file_path) + analyzed_classes, imports, types = parse_file(file_path) all_imports.update(imports) all_types.update(types) - for class_info in structure: + for class_info in analyzed_classes: class_name = class_info["class_name"] if not _should_include_class(class_name, class_filters): continue diff --git a/scripts/microgenerator/noxfile.py b/scripts/microgenerator/noxfile.py index 9f4cd7e2c..065ff192e 100644 --- a/scripts/microgenerator/noxfile.py +++ b/scripts/microgenerator/noxfile.py @@ -16,7 +16,6 @@ from functools import wraps import pathlib -import os import nox import time @@ -26,7 +25,7 @@ BLACK_VERSION = "black==23.7.0" BLACK_PATHS = (".",) -DEFAULT_PYTHON_VERSION = "3.9" +DEFAULT_PYTHON_VERSION = "3.13" UNIT_TEST_PYTHON_VERSIONS = ["3.9", "3.11", "3.12", "3.13"] CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -190,9 +189,8 @@ def lint(session): session.install("flake8", BLACK_VERSION) session.install("-e", ".") session.run("python", "-m", "pip", 
"freeze") - session.run("flake8", os.path.join("scripts")) + session.run("flake8", ".") session.run("flake8", "tests") - session.run("flake8", "benchmark") session.run("black", "--check", *BLACK_PATHS) diff --git a/scripts/microgenerator/tests/unit/test_generate_analyzer.py b/scripts/microgenerator/tests/unit/test_generate_analyzer.py index 71ee37bf4..77aca9bf1 100644 --- a/scripts/microgenerator/tests/unit/test_generate_analyzer.py +++ b/scripts/microgenerator/tests/unit/test_generate_analyzer.py @@ -95,7 +95,7 @@ def test_import_extraction(self, code_snippet, expected_imports): class TestCodeAnalyzerAttributes: @pytest.mark.parametrize( - "code_snippet, expected_structure", + "code_snippet, expected_analyzed_classes", [ pytest.param( """ @@ -243,22 +243,24 @@ def __init__(self): ), ], ) - def test_attribute_extraction(self, code_snippet: str, expected_structure: list): + def test_attribute_extraction( + self, code_snippet: str, expected_analyzed_classes: list + ): """Tests the extraction of class and instance attributes.""" analyzer = CodeAnalyzer() tree = ast.parse(code_snippet) analyzer.visit(tree) - extracted = analyzer.structure + extracted = analyzer.analyzed_classes # Normalize attributes for order-independent comparison for item in extracted: if "attributes" in item: item["attributes"].sort(key=lambda x: x["name"]) - for item in expected_structure: + for item in expected_analyzed_classes: if "attributes" in item: item["attributes"].sort(key=lambda x: x["name"]) - assert extracted == expected_structure + assert extracted == expected_analyzed_classes # --- Mock Types --- @@ -284,8 +286,8 @@ class MyClass: analyzer = CodeAnalyzer() tree = ast.parse(code) analyzer.visit(tree) - assert len(analyzer.structure) == 1 - assert analyzer.structure[0]["class_name"] == "MyClass" + assert len(analyzer.analyzed_classes) == 1 + assert analyzer.analyzed_classes[0]["class_name"] == "MyClass" def test_codeanalyzer_finds_multiple_classes(): @@ -302,8 +304,8 @@ class ClassB: 
analyzer = CodeAnalyzer() tree = ast.parse(code) analyzer.visit(tree) - assert len(analyzer.structure) == 2 - class_names = sorted([c["class_name"] for c in analyzer.structure]) + assert len(analyzer.analyzed_classes) == 2 + class_names = sorted([c["class_name"] for c in analyzer.analyzed_classes]) assert class_names == ["ClassA", "ClassB"] @@ -318,9 +320,9 @@ def my_method(self): analyzer = CodeAnalyzer() tree = ast.parse(code) analyzer.visit(tree) - assert len(analyzer.structure) == 1 - assert len(analyzer.structure[0]["methods"]) == 1 - assert analyzer.structure[0]["methods"][0]["method_name"] == "my_method" + assert len(analyzer.analyzed_classes) == 1 + assert len(analyzer.analyzed_classes[0]["methods"]) == 1 + assert analyzer.analyzed_classes[0]["methods"][0]["method_name"] == "my_method" def test_codeanalyzer_finds_multiple_methods(): @@ -337,8 +339,8 @@ def method_b(self): analyzer = CodeAnalyzer() tree = ast.parse(code) analyzer.visit(tree) - assert len(analyzer.structure) == 1 - method_names = sorted([m["method_name"] for m in analyzer.structure[0]["methods"]]) + assert len(analyzer.analyzed_classes) == 1 + method_names = sorted([m["method_name"] for m in analyzer.analyzed_classes[0]["methods"]]) assert method_names == ["method_a", "method_b"] @@ -352,7 +354,7 @@ def top_level_function(): analyzer = CodeAnalyzer() tree = ast.parse(code) analyzer.visit(tree) - assert len(analyzer.structure) == 0 + assert len(analyzer.analyzed_classes) == 0 def test_codeanalyzer_class_with_no_methods(): @@ -365,9 +367,9 @@ class MyClass: analyzer = CodeAnalyzer() tree = ast.parse(code) analyzer.visit(tree) - assert len(analyzer.structure) == 1 - assert analyzer.structure[0]["class_name"] == "MyClass" - assert len(analyzer.structure[0]["methods"]) == 0 + assert len(analyzer.analyzed_classes) == 1 + assert analyzer.analyzed_classes[0]["class_name"] == "MyClass" + assert len(analyzer.analyzed_classes[0]["methods"]) == 0 # --- Test Data for Parameterization --- @@ -487,10 
+489,10 @@ class TestCodeAnalyzerArgsReturns: "code_snippet, expected_args, expected_return", TYPE_TEST_CASES ) def test_type_extraction(self, code_snippet, expected_args, expected_return): - structure, imports, types = parse_code(code_snippet) + analyzed_classes, imports, types = parse_code(code_snippet) - assert len(structure) == 1, "Should parse one class" - class_info = structure[0] + assert len(analyzed_classes) == 1, "Should parse one class" + class_info = analyzed_classes[0] assert class_info["class_name"] == "TestClass" assert len(class_info["methods"]) == 1, "Should find one method" @@ -506,3 +508,4 @@ def test_type_extraction(self, code_snippet, expected_args, expected_return): assert extracted_args == expected_args assert method_info.get("return_type") == expected_return + diff --git a/scripts/microgenerator/tests/unit/test_generate_parser.py b/scripts/microgenerator/tests/unit/test_generate_parser.py new file mode 100644 index 000000000..610d2dfc1 --- /dev/null +++ b/scripts/microgenerator/tests/unit/test_generate_parser.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest +from unittest import mock +import textwrap as tw + +from scripts.microgenerator import generate + + +# --- Tests for parse_code() --- +def test_parse_code_empty(): + analyzed_classes, imports, types = generate.parse_code("") + assert analyzed_classes == [] + assert imports == set() + assert types == set() + + +def test_parse_code_simple_class(): + code = tw.dedent( + """ + class MyClass: + pass + """ + ) + analyzed_classes, _, _ = generate.parse_code(code) + assert len(analyzed_classes) == 1 + assert analyzed_classes[0]["class_name"] == "MyClass" + + +def test_parse_code_simple_function(): + code = tw.dedent( + """ + def my_function(): + pass + """ + ) + + # In the microgenerator, the focus is parsing major classes (and their + # associated methods) in the GAPIC generated code, not parsing top-level + # functions. Thus we do not expect it to capture this top-level function. + analyzed_classes, _, _ = generate.parse_code(code) + assert len(analyzed_classes) == 0 + + +def test_parse_code_invalid_syntax(): + with pytest.raises(SyntaxError): + # incorrect indentation and missing trailing colon on func definition. + code = tw.dedent( + """ + class MyClass: + pass + def func() + pass + """ + ) + generate.parse_code(code) + + +def test_parse_code_with_imports_and_types(): + code = tw.dedent( + """ + import os + import sys as system + from typing import List, Optional, Dict + from . import my_module + + class MyClass: + attr: Dict[str, int] + def method(self, x: List[str]) -> Optional[int]: + return None + def method2(self, y: 'MyClass') -> None: + pass + """ + ) + analyzed_classes, imports, types = generate.parse_code(code) + + expected_imports = { + "import os", + "import sys as system", + "from typing import List, Optional, Dict", + "from . 
import my_module", + } + assert imports == expected_imports + + expected_types = { + "Dict", + "str", + "int", + "List", + "Optional", + "MyClass", + "None", + } + assert types == expected_types + assert len(analyzed_classes) == 1 + + +# --- Tests for parse_file() --- +# parse_file() wraps parse_code() and simply reads in content from a file +# as a string using the built in open() function and passes the string intact +# to parse_code(). +@mock.patch("builtins.open", new_callable=mock.mock_open) +def test_parse_file_reads_and_parses(mock_file): + read_data = tw.dedent( + """ + class TestClass: + pass + """ + ) + mock_file.return_value.read.return_value = read_data + analyzed_classes, _, _ = generate.parse_file("dummy/path/file.py") + mock_file.assert_called_once_with("dummy/path/file.py", "r", encoding="utf-8") + assert len(analyzed_classes) == 1 + assert analyzed_classes[0]["class_name"] == "TestClass" + + +@mock.patch("builtins.open", side_effect=FileNotFoundError) +def test_parse_file_not_found(mock_file): + with pytest.raises(FileNotFoundError): + generate.parse_file("nonexistent.py") + mock_file.assert_called_once_with("nonexistent.py", "r", encoding="utf-8") + + +@mock.patch("builtins.open", new_callable=mock.mock_open) +def test_parse_file_syntax_error(mock_file): + mock_file.return_value.read.return_value = "a = (" + with pytest.raises(SyntaxError): + generate.parse_file("syntax_error.py") + mock_file.assert_called_once_with("syntax_error.py", "r", encoding="utf-8") + + +@mock.patch( + "scripts.microgenerator.generate.parse_code", return_value=([], set(), set()) +) +@mock.patch("builtins.open", new_callable=mock.mock_open) +def test_parse_file_calls_parse_code(mock_file, mock_parse_code): + """This test simply confirms that parse_code() gets called internally. + + Other parse_code tests ensure that it works as expected. 
+ """ + read_data = "some code" + mock_file.return_value.read.return_value = read_data + generate.parse_file("some_file.py") + mock_parse_code.assert_called_once_with(read_data)