diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py index ccaff07a4..2fc356464 100644 --- a/scripts/microgenerator/generate.py +++ b/scripts/microgenerator/generate.py @@ -27,9 +27,9 @@ import argparse import glob import logging -import re from collections import defaultdict -from typing import List, Dict, Any, Iterator +from pathlib import Path +from typing import List, Dict, Any from . import name_utils from . import utils @@ -51,6 +51,7 @@ def __init__(self): self.types: set[str] = set() self._current_class_info: Dict[str, Any] | None = None self._is_in_method: bool = False + self._depth = 0 def _get_type_str(self, node: ast.AST | None) -> str | None: """Recursively reconstructs a type annotation string from an AST node.""" @@ -112,30 +113,32 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None: def visit_Import(self, node: ast.Import) -> None: """Catches 'import X' and 'import X as Y' statements.""" - for alias in node.names: - if alias.asname: - self.imports.add(f"import {alias.name} as {alias.asname}") - else: - self.imports.add(f"import {alias.name}") + if self._depth == 0: # Only top-level imports + for alias in node.names: + if alias.asname: + self.imports.add(f"import {alias.name} as {alias.asname}") + else: + self.imports.add(f"import {alias.name}") self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Catches 'from X import Y' statements.""" - module = node.module or "" - if not module: - module = "." * node.level - else: - module = "." * node.level + module - - names = [] - for alias in node.names: - if alias.asname: - names.append(f"{alias.name} as {alias.asname}") + if self._depth == 0: # Only top-level imports + module = node.module or "" + if not module: + module = "." * node.level else: - names.append(alias.name) + module = "." * node.level + module + + names = [] + for alias in node.names: + if alias.asname: + names.append(f"{alias.name} as {alias.asname}") + else: + names.append(alias.name) - if names: - self.imports.add(f"from {module} import {', '.join(names)}") + if names: + self.imports.add(f"from {module} import {', '.join(names)}") self.generic_visit(node) def visit_ClassDef(self, node: ast.ClassDef) -> None: @@ -155,12 +158,15 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: self.structure.append(class_info) self._current_class_info = class_info + self._depth += 1 self.generic_visit(node) + self._depth -= 1 self._current_class_info = None def visit_FunctionDef(self, node: ast.FunctionDef) -> None: """Visits a function/method definition node.""" - if self._current_class_info: # This is a method + is_method = self._current_class_info is not None + if is_method: args_info = [] # Get default values @@ -189,10 +195,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: "return_type": return_type, } self._current_class_info["methods"].append(method_info) - - # Visit nodes inside the method to find instance attributes. self._is_in_method = True - self.generic_visit(node) + + self._depth += 1 + self.generic_visit(node) + self._depth -= 1 + + if is_method: self._is_in_method = False def _add_attribute(self, attr_name: str, attr_type: str | None = None): diff --git a/scripts/microgenerator/tests/unit/test_generate_analyzer.py b/scripts/microgenerator/tests/unit/test_generate_analyzer.py new file mode 100644 index 000000000..67e9f563d --- /dev/null +++ b/scripts/microgenerator/tests/unit/test_generate_analyzer.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import ast +import pytest +from scripts.microgenerator.generate import CodeAnalyzer + +# --- Tests CodeAnalyzer handling of Imports --- + + +class TestCodeAnalyzerImports: + @pytest.mark.parametrize( + "code_snippet, expected_imports", + [ + pytest.param( + "import os\nimport sys", + ["import os", "import sys"], + id="simple_imports", + ), + pytest.param( + "import numpy as np", + ["import numpy as np"], + id="aliased_import", + ), + pytest.param( + "from collections import defaultdict, OrderedDict", + ["from collections import defaultdict, OrderedDict"], + id="from_import_multiple", + ), + pytest.param( + "from typing import List as L", + ["from typing import List as L"], + id="from_import_aliased", + ), + pytest.param( + "from math import *", + ["from math import *"], + id="from_import_wildcard", + ), + pytest.param( + "import os.path", + ["import os.path"], + id="dotted_import", + ), + pytest.param( + "from google.cloud import bigquery", + ["from google.cloud import bigquery"], + id="from_dotted_module", + ), + pytest.param( + "", + [], + id="no_imports", + ), + pytest.param( + "class MyClass:\n import json # Should not be picked up", + [], + id="import_inside_class", + ), + pytest.param( + "def my_func():\n from time import sleep # Should not be picked up", + [], + id="import_inside_function", + ), + ], + ) + def test_import_extraction(self, code_snippet, expected_imports): + analyzer = CodeAnalyzer() + tree = ast.parse(code_snippet) + analyzer.visit(tree) + + # Normalize for comparison + extracted = sorted(list(analyzer.imports)) + expected = sorted(expected_imports) + + assert extracted == expected