diff --git a/.chronus/changes/python-sample-test-generation-optimization-2026-0-29-6-46-35.md b/.chronus/changes/python-sample-test-generation-optimization-2026-0-29-6-46-35.md new file mode 100644 index 00000000000..490c7cbe940 --- /dev/null +++ b/.chronus/changes/python-sample-test-generation-optimization-2026-0-29-6-46-35.md @@ -0,0 +1,7 @@ +--- +changeKind: internal +packages: + - "@typespec/http-client-python" +--- + +Optimize sdk generation performance \ No newline at end of file diff --git a/packages/http-client-python/emitter/src/emitter.ts b/packages/http-client-python/emitter/src/emitter.ts index 9534c1ec99f..2916a55f6d4 100644 --- a/packages/http-client-python/emitter/src/emitter.ts +++ b/packages/http-client-python/emitter/src/emitter.ts @@ -239,7 +239,7 @@ async function onEmitMain(context: EmitContext) { ".mypy_cache", ".pytest_cache", ".vscode", - "_build", + ".*_build/", "/build/", "dist", ".nox", diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py b/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py index e7654299885..6a16ed93e48 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py @@ -34,11 +34,7 @@ from .test_serializer import TestSerializer, TestGeneralSerializer from .types_serializer import TypesSerializer from ...utils import to_snake_case, VALID_PACKAGE_MODE -from .utils import ( - extract_sample_name, - get_namespace_from_package_name, - get_namespace_config, -) +from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import _LOGGER = logging.getLogger(__name__) @@ -536,39 +532,72 @@ def sample_additional_folder(self) -> Path: def _generated_tests_samples_folder(self, folder_name: str) -> Path: return self._root_of_sdk / folder_name + def _process_operation_samples( + self, + samples: dict, + env: Environment, + op_group, + operation, + import_sample_cache: dict[tuple[str, str], str], + out_path: Path, + sample_additional_folder: Path, + ) -> None: + """Process samples for a single operation.""" + for sample_value in samples.values(): + file = sample_value.get("x-ms-original-file", "sample.json") + file_name = to_snake_case(extract_sample_name(file)) + ".py" + try: + sample_ser = SampleSerializer( + code_model=self.code_model, + env=env, + operation_group=op_group, + operation=operation, + sample=sample_value, + file_name=file_name, + ) + file_import = sample_ser.get_file_import() + imports_hash_string = hash_file_import(file_import) + cache_key = (op_group.client.client_namespace, imports_hash_string) + if cache_key not in import_sample_cache: + import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import) + sample_ser.imports = import_sample_cache[cache_key] + + content = sample_ser.serialize() + output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name + self.write_file(output_path, content) + except Exception as e: # pylint: disable=broad-except + _LOGGER.error("error happens in sample %s: %s", file, e) + def _serialize_and_write_sample(self, env: Environment): out_path = self._generated_tests_samples_folder("generated_samples") + sample_additional_folder = self.sample_additional_folder + + # Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute + import_sample_cache: dict[tuple[str, str], str] = {} + for client in self.code_model.clients: for op_group in client.operation_groups: for operation in op_group.operations: samples = operation.yaml_data.get("samples") if not samples or operation.name.startswith("_"): continue - for value in samples.values(): - file = value.get("x-ms-original-file", "sample.json") - file_name = to_snake_case(extract_sample_name(file)) + ".py" - try: - self.write_file( - out_path / self.sample_additional_folder / _sample_output_path(file) / file_name, - SampleSerializer( - code_model=self.code_model, - env=env, - operation_group=op_group, - operation=operation, - sample=value, - file_name=file_name, - ).serialize(), - ) - except Exception as e: # pylint: disable=broad-except - # sample generation shall not block code generation, so just log error - log_error = f"error happens in sample {file}: {e}" - _LOGGER.error(log_error) + self._process_operation_samples( + samples, + env, + op_group, + operation, + import_sample_cache, + out_path, + sample_additional_folder, + ) def _serialize_and_write_test(self, env: Environment): self.code_model.for_test = True out_path = self._generated_tests_samples_folder("generated_tests") + general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env) self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest()) + if not self.code_model.options["azure-arm"]: for async_mode in (True, False): async_suffix = "_async" if async_mode else "" @@ -578,18 +607,24 @@ def _serialize_and_write_test(self, env: Environment): general_serializer.serialize_testpreparer(), ) + # Generate test files - reuse serializer per operation group, toggle async_mode + # Cache import_test per (client.name, async_mode) since it's expensive to compute + import_test_cache: dict[tuple[str, bool], str] = {} for client in self.code_model.clients: for og in client.operation_groups: + # Create serializer once per operation group test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og) - for async_mode in (True, False): - try: + try: + for async_mode in (True, False): test_serializer.async_mode = async_mode - self.write_file( - out_path / f"{to_snake_case(test_serializer.test_class_name)}.py", - test_serializer.serialize_test(), - ) - except Exception as e: # pylint: disable=broad-except - # test generation shall not block code generation, so just log error - log_error = f"error happens in test generation for operation group {og.class_name}: {e}" - _LOGGER.error(log_error) + cache_key = (client.name, async_mode) + if cache_key not in import_test_cache: + import_test_cache[cache_key] = test_serializer.get_import_test() + test_serializer.import_test = import_test_cache[cache_key] + content = test_serializer.serialize_test() + output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py" + self.write_file(output_path, content) + except Exception as e: # pylint: disable=broad-except + _LOGGER.error("error happens in test generation for operation group %s: %s", og.class_name, e) + self.code_model.for_test = False diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py index 3fd57800168..733209b8396 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/sample_serializer.py @@ -20,6 +20,7 @@ BodyParameter, FileImport, ) +from .utils import create_fake_value _LOGGER = logging.getLogger(__name__) @@ -40,8 +41,17 @@ def __init__( self.sample = sample self.file_name = file_name self.sample_params = sample.get("parameters", {}) + self._imports: str = "" - def _imports(self) -> FileImportSerializer: + @property + def imports(self) -> str: + return self._imports + + @imports.setter + def imports(self, value: str) -> None: + self._imports = value + + def get_file_import(self) -> FileImport: imports = FileImport(self.code_model) client = self.operation_group.client namespace = client.client_namespace @@ -59,7 +69,12 @@ def _imports(self) -> FileImportSerializer: for param in self.operation.parameters.positional + self.operation.parameters.keyword_only: if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params: imports.merge(param.type.imports_for_sample()) - return FileImportSerializer(imports, True) + + return imports + + @staticmethod + def get_imports_from_file_import(file_import: FileImport) -> str: + return str(FileImportSerializer(file_import, True)) def _client_params(self) -> dict[str, Any]: # client params @@ -97,19 +112,14 @@ def handle_param(param: Union[Parameter, BodyParameter], param_value: Any) -> st # prepare operation parameters def _operation_params(self) -> dict[str, Any]: - params = [ - p - for p in (self.operation.parameters.positional + self.operation.parameters.keyword_only) - if not p.client_default_value - ] - failure_info = "fail to find required param named {}" operation_params = {} - for param in params: - if not param.optional: + for param in self.operation.parameters.positional + self.operation.parameters.keyword_only: + if not param.optional and not param.client_default_value: param_value = self.sample_params.get(param.wire_name) if not param_value: - raise Exception(failure_info.format(param.client_name)) # pylint: disable=broad-exception-raised - operation_params[param.client_name] = self.handle_param(param, param_value) + operation_params[param.client_name] = create_fake_value(param.type) + else: + operation_params[param.client_name] = self.handle_param(param, param_value) return operation_params def _operation_group_name(self) -> str: @@ -154,7 +164,7 @@ def serialize(self) -> str: operation_params=self._operation_params(), operation_group_name=self._operation_group_name(), operation_name=self._operation_name(), - imports=self._imports(), + imports=self.imports, client_params=self._client_params(), origin_file=self._origin_file(), return_var=return_var, diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/test_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/test_serializer.py index 73baf5dd8d2..410b45cfecb 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/test_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/test_serializer.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import Any, Optional +from typing import Any from jinja2 import Environment from .import_serializer import FileImportSerializer @@ -14,12 +14,9 @@ OperationGroup, Client, OperationType, - ModelType, - BaseType, - CombinedType, FileImport, ) -from .utils import json_dumps_template +from .utils import create_fake_value def is_lro(operation_type: str) -> bool: @@ -39,30 +36,15 @@ def __init__(self, code_model: CodeModel, client_name: str, *, async_mode: bool self.code_model = code_model self.client_name = client_name self.async_mode = async_mode - - @property - def async_suffix_capt(self) -> str: - return "Async" if self.async_mode else "" - - @property - def create_client_name(self) -> str: - return "create_async_client" if self.async_mode else "create_client" - - @property - def prefix(self) -> str: - return self.client_name.replace("Client", "") - - @property - def preparer_name(self) -> str: - if self.code_model.options["azure-arm"]: - return "RandomNameResourceGroupPreparer" - return self.prefix + "Preparer" - - @property - def base_test_class_name(self) -> str: - if self.code_model.options["azure-arm"]: - return "AzureMgmtRecordedTestCase" - return f"{self.client_name}TestBase{self.async_suffix_capt}" + # Pre-compute values for render speed optimization + self.async_suffix_capt = "Async" if async_mode else "" + self.create_client_name = "create_async_client" if async_mode else "create_client" + self.prefix = client_name.replace("Client", "") + is_azure_arm = code_model.options["azure-arm"] + self.preparer_name = "RandomNameResourceGroupPreparer" if is_azure_arm else self.prefix + "Preparer" + self.base_test_class_name = ( + "AzureMgmtRecordedTestCase" if is_azure_arm else f"{client_name}TestBase{self.async_suffix_capt}" + ) class TestCase: @@ -73,50 +55,52 @@ def __init__( operation: OperationType, *, async_mode: bool = False, + is_azure_arm: bool = False, ) -> None: self.operation_groups = operation_groups - self.params = params self.operation = operation self.async_mode = async_mode - - @property - def name(self) -> str: - if self.operation_groups[-1].is_mixin: - return self.operation.name - return "_".join([og.property_name for og in self.operation_groups] + [self.operation.name]) - - @property - def operation_group_prefix(self) -> str: - if self.operation_groups[-1].is_mixin: - return "" - return "." + ".".join([og.property_name for og in self.operation_groups]) - - @property - def response(self) -> str: - if self.async_mode: - if is_lro(self.operation.operation_type): - return "response = await (await " - if is_common_operation(self.operation.operation_type): - return "response = await " - return "response = " - - @property - def lro_comment(self) -> str: - return " # call '.result()' to poll until service return final result" - - @property - def operation_suffix(self) -> str: - if is_lro(self.operation.operation_type): - extra = ")" if self.async_mode else "" - return f"{extra}.result(){self.lro_comment}" - return "" - - @property - def extra_operation(self) -> str: - if is_paging(self.operation.operation_type): - async_str = "async " if self.async_mode else "" - return f"result = [r {async_str}for r in response]" - return "" + self.is_azure_arm = is_azure_arm + # Pre-compute params + if is_azure_arm: + self.params = {k: ("resource_group.name" if k == "resource_group_name" else v) for k, v in params.items()} + else: + self.params = params + # Pre-compute name + if operation_groups[-1].is_mixin: + self.name = operation.name + else: + self.name = "_".join([og.property_name for og in operation_groups] + [operation.name]) + # Pre-compute operation_group_prefix + if operation_groups[-1].is_mixin: + self.operation_group_prefix = "" + else: + self.operation_group_prefix = "." + ".".join([og.property_name for og in operation_groups]) + # Pre-compute response + operation_type = operation.operation_type + if async_mode: + if is_lro(operation_type): + self.response = "response = await (await " + elif is_common_operation(operation_type): + self.response = "response = await " + else: + self.response = "response = " + else: + self.response = "response = " + # Pre-compute lro_comment + self.lro_comment = " # call '.result()' to poll until service return final result" + # Pre-compute operation_suffix + if is_lro(operation_type): + extra = ")" if async_mode else "" + self.operation_suffix = f"{extra}.result(){self.lro_comment}" + else: + self.operation_suffix = "" + # Pre-compute extra_operation + if is_paging(operation_type): + async_str = "async " if async_mode else "" + self.extra_operation = f"result = [r {async_str}for r in response]" + else: + self.extra_operation = "" class Test(TestName): @@ -189,9 +173,17 @@ def __init__( super().__init__(code_model, env, async_mode=async_mode) self.client = client self.operation_group = operation_group + self._import_test: str = "" @property - def import_test(self) -> FileImportSerializer: + def import_test(self) -> str: + return self._import_test + + @import_test.setter + def import_test(self, value: str) -> None: + self._import_test = value + + def get_import_test(self) -> str: imports = self.init_file_import() test_name = TestName(self.code_model, self.client.name, async_mode=self.async_mode) async_suffix = "_async" if self.async_mode else "" @@ -212,7 +204,7 @@ def import_test(self) -> FileImportSerializer: ) if self.code_model.options["azure-arm"]: self.add_import_client(imports) - return FileImportSerializer(imports, self.async_mode) + return str(FileImportSerializer(imports, self.async_mode)) @property def breadth_search_operation_group(self) -> list[list[OperationGroup]]: @@ -226,26 +218,11 @@ def breadth_search_operation_group(self) -> list[list[OperationGroup]]: queue.extend([current + [og] for og in current[-1].operation_groups]) return result - def get_sub_type(self, param_type: ModelType) -> ModelType: - if param_type.discriminated_subtypes: - for item in param_type.discriminated_subtypes.values(): - return self.get_sub_type(item) - return param_type - - def get_model_type(self, param_type: BaseType) -> Optional[ModelType]: - if isinstance(param_type, ModelType): - return param_type - if isinstance(param_type, CombinedType): - return param_type.target_model_subtype((ModelType,)) - return None - def get_operation_params(self, operation: OperationType) -> dict[str, Any]: operation_params = {} required_params = [p for p in operation.parameters.method if not p.optional] for param in required_params: - model_type = self.get_model_type(param.type) - param_type = self.get_sub_type(model_type) if model_type else param.type - operation_params[param.client_name] = json_dumps_template(param_type.get_json_template_representation()) + operation_params[param.client_name] = create_fake_value(param.type) return operation_params def get_test(self) -> Test: @@ -260,6 +237,7 @@ def get_test(self) -> Test: params=operation_params, operation=operation, async_mode=self.async_mode, + is_azure_arm=self.code_model.options["azure-arm"], ) testcases.append(testcase) if not testcases: @@ -283,6 +261,7 @@ def test_class_name(self) -> str: def serialize_test(self) -> str: return self.env.get_template("test.py.jinja2").render( imports=self.import_test, - code_model=self.code_model, + is_azure_arm=self.code_model.options["azure-arm"], + license_header=self.code_model.license_header, test=self.get_test(), ) diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/utils.py b/packages/http-client-python/generator/pygen/codegen/serializers/utils.py index 9ea6c85c77b..52ee4d62e57 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/utils.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/utils.py @@ -7,6 +7,15 @@ from typing import Optional, Any from pathlib import Path +from ..models import ModelType, BaseType, CombinedType, FileImport + + +def get_sub_type(param_type: ModelType) -> ModelType: + if param_type.discriminated_subtypes: + for item in param_type.discriminated_subtypes.values(): + return get_sub_type(item) + return param_type + def method_signature_and_response_type_annotation_template( *, @@ -52,3 +61,31 @@ def _improve_json_string(template_representation: str) -> Any: def json_dumps_template(template_representation: Any) -> Any: # only for template use, since it wraps everything in strings return _improve_json_string(json.dumps(template_representation, indent=4)) + + +def create_fake_value(param_type: BaseType) -> Any: + """Create a fake value for a parameter type by getting its JSON template representation. + + This function generates a fake value suitable for samples and tests. + + :param param_type: The parameter type to create a fake value for. + :return: A string representation of the fake value. + """ + + model_type: Optional[ModelType] = None + if isinstance(param_type, ModelType): + model_type = param_type + elif isinstance(param_type, CombinedType): + model_type = param_type.target_model_subtype((ModelType,)) + resolved_type = get_sub_type(model_type) if model_type else param_type + return json_dumps_template(resolved_type.get_json_template_representation()) + + +def hash_file_import(file_import: FileImport) -> str: + """Generate a hash for a FileImport object based on its imports. + + :param file_import: The FileImport object to generate a hash for. + :return: A string representing the hash of the FileImport object. + """ + + return "".join(sorted({str(hash(i)) for i in file_import.imports})) diff --git a/packages/http-client-python/generator/pygen/codegen/templates/test.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/test.py.jinja2 index 40b9e06a600..e3f4fc4df37 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/test.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/test.py.jinja2 @@ -1,48 +1,51 @@ {% set prefix_lower = test.prefix|lower %} -{% set client_var = "self.client" if code_model.options["azure-arm"] else "client" %} +{% set client_var = "self.client" if is_azure_arm else "client" %} {% set async = "async " if test.async_mode else "" %} {% set async_suffix = "_async" if test.async_mode else "" %} # coding=utf-8 -{% if code_model.license_header %} -{{ code_model.license_header }} +{% if license_header %} +{{ license_header }} {% endif %} import pytest {{ imports }} -{% if code_model.options["azure-arm"] %} +{% if is_azure_arm %} AZURE_LOCATION = "eastus" {% endif %} @pytest.mark.skip("you may need to update the auto-generated test case before run it") class {{ test.test_class_name }}({{ test.base_test_class_name }}): -{% if code_model.options["azure-arm"] %} +{% if is_azure_arm %} def setup_method(self, method): {% if test.async_mode %} self.client = self.create_mgmt_client({{ test.client_name }}, is_async=True) {% else %} self.client = self.create_mgmt_client({{ test.client_name }}) {% endif %} -{% endif %} + {% for testcase in test.testcases %} - {% if code_model.options["azure-arm"] %} @{{ test.preparer_name }}(location=AZURE_LOCATION) - {% else %} - @{{ test.preparer_name }}() - {% endif %} @recorded_by_proxy{{ async_suffix }} - {% if code_model.options["azure-arm"] %} {{ async }}def test_{{ testcase.name }}(self, resource_group): - {% else %} + {{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}( + {% for key, value in testcase.params.items() %} + {{ key }}={{ value }}, + {% endfor %} + ){{ testcase.operation_suffix }} + {{ testcase.extra_operation }} + # please add some check logic here by yourself + # ... + +{% endfor %} +{% else %} +{% for testcase in test.testcases %} + @{{ test.preparer_name }}() + @recorded_by_proxy{{ async_suffix }} {{ async }}def test_{{ testcase.name }}(self, {{ prefix_lower }}_endpoint): {{ client_var }} = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint) - {% endif %} {{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}( {% for key, value in testcase.params.items() %} - {% if code_model.options["azure-arm"] and key == "resource_group_name" %} - {{ key }}=resource_group.name, - {% else %} - {{ key }}={{ value|indent(12) }}, - {% endif %} + {{ key }}={{ value }}, {% endfor %} ){{ testcase.operation_suffix }} {{ testcase.extra_operation }} @@ -50,3 +53,4 @@ class {{ test.test_class_name }}({{ test.base_test_class_name }}): # ... {% endfor %} +{% endif %}