Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions scripts/microgenerator/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import argparse
import glob
import logging
import re
from collections import defaultdict
from typing import List, Dict, Any, Iterator
from pathlib import Path
from typing import List, Dict, Any

from . import name_utils
from . import utils
Expand All @@ -51,6 +51,7 @@ def __init__(self):
self.types: set[str] = set()
self._current_class_info: Dict[str, Any] | None = None
self._is_in_method: bool = False
self._depth = 0

def _get_type_str(self, node: ast.AST | None) -> str | None:
"""Recursively reconstructs a type annotation string from an AST node."""
Expand Down Expand Up @@ -112,30 +113,32 @@ def _collect_types_from_node(self, node: ast.AST | None) -> None:

def visit_Import(self, node: ast.Import) -> None:
"""Catches 'import X' and 'import X as Y' statements."""
for alias in node.names:
if alias.asname:
self.imports.add(f"import {alias.name} as {alias.asname}")
else:
self.imports.add(f"import {alias.name}")
if self._depth == 0: # Only top-level imports
for alias in node.names:
if alias.asname:
self.imports.add(f"import {alias.name} as {alias.asname}")
else:
self.imports.add(f"import {alias.name}")
self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Catches 'from X import Y' statements."""
module = node.module or ""
if not module:
module = "." * node.level
else:
module = "." * node.level + module

names = []
for alias in node.names:
if alias.asname:
names.append(f"{alias.name} as {alias.asname}")
if self._depth == 0: # Only top-level imports
module = node.module or ""
if not module:
module = "." * node.level
else:
names.append(alias.name)
module = "." * node.level + module

names = []
for alias in node.names:
if alias.asname:
names.append(f"{alias.name} as {alias.asname}")
else:
names.append(alias.name)

if names:
self.imports.add(f"from {module} import {', '.join(names)}")
if names:
self.imports.add(f"from {module} import {', '.join(names)}")
self.generic_visit(node)

def visit_ClassDef(self, node: ast.ClassDef) -> None:
Expand All @@ -155,12 +158,15 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:

self.structure.append(class_info)
self._current_class_info = class_info
self._depth += 1
self.generic_visit(node)
self._depth -= 1
self._current_class_info = None

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
"""Visits a function/method definition node."""
if self._current_class_info: # This is a method
is_method = self._current_class_info is not None
if is_method:
args_info = []

# Get default values
Expand Down Expand Up @@ -189,10 +195,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
"return_type": return_type,
}
self._current_class_info["methods"].append(method_info)

# Visit nodes inside the method to find instance attributes.
self._is_in_method = True
self.generic_visit(node)

self._depth += 1
self.generic_visit(node)
self._depth -= 1

if is_method:
self._is_in_method = False

def _add_attribute(self, attr_name: str, attr_type: str | None = None):
Expand Down
89 changes: 89 additions & 0 deletions scripts/microgenerator/tests/unit/test_generate_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import ast
import pytest
from scripts.microgenerator.generate import CodeAnalyzer

# --- Tests CodeAnalyzer handling of Imports ---


class TestCodeAnalyzerImports:
@pytest.mark.parametrize(
"code_snippet, expected_imports",
[
pytest.param(
"import os\nimport sys",
["import os", "import sys"],
id="simple_imports",
),
pytest.param(
"import numpy as np",
["import numpy as np"],
id="aliased_import",
),
pytest.param(
"from collections import defaultdict, OrderedDict",
["from collections import defaultdict, OrderedDict"],
id="from_import_multiple",
),
pytest.param(
"from typing import List as L",
["from typing import List as L"],
id="from_import_aliased",
),
pytest.param(
"from math import *",
["from math import *"],
id="from_import_wildcard",
),
pytest.param(
"import os.path",
["import os.path"],
id="dotted_import",
),
pytest.param(
"from google.cloud import bigquery",
["from google.cloud import bigquery"],
id="from_dotted_module",
),
pytest.param(
"",
[],
id="no_imports",
),
pytest.param(
"class MyClass:\n import json # Should not be picked up",
[],
id="import_inside_class",
),
pytest.param(
"def my_func():\n from time import sleep # Should not be picked up",
[],
id="import_inside_function",
),
],
)
def test_import_extraction(self, code_snippet, expected_imports):
analyzer = CodeAnalyzer()
tree = ast.parse(code_snippet)
analyzer.visit(tree)

# Normalize for comparison
extracted = sorted(list(analyzer.imports))
expected = sorted(expected_imports)

assert extracted == expected
Loading