Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: internal
packages:
- "@typespec/http-client-python"
---

Optimize sdk generation performance
2 changes: 1 addition & 1 deletion packages/http-client-python/emitter/src/emitter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ async function onEmitMain(context: EmitContext<PythonEmitterOptions>) {
".mypy_cache",
".pytest_cache",
".vscode",
"_build",
".*_build/",
"/build/",
"dist",
".nox",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@
from .test_serializer import TestSerializer, TestGeneralSerializer
from .types_serializer import TypesSerializer
from ...utils import to_snake_case, VALID_PACKAGE_MODE
from .utils import (
extract_sample_name,
get_namespace_from_package_name,
get_namespace_config,
)
from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -536,39 +532,72 @@ def sample_additional_folder(self) -> Path:
def _generated_tests_samples_folder(self, folder_name: str) -> Path:
return self._root_of_sdk / folder_name

def _process_operation_samples(
self,
samples: dict,
env: Environment,
op_group,
operation,
import_sample_cache: dict[tuple[str, str], str],
out_path: Path,
sample_additional_folder: Path,
) -> None:
"""Process samples for a single operation."""
for sample_value in samples.values():
file = sample_value.get("x-ms-original-file", "sample.json")
file_name = to_snake_case(extract_sample_name(file)) + ".py"
try:
sample_ser = SampleSerializer(
code_model=self.code_model,
env=env,
operation_group=op_group,
operation=operation,
sample=sample_value,
file_name=file_name,
)
file_import = sample_ser.get_file_import()
imports_hash_string = hash_file_import(file_import)
cache_key = (op_group.client.client_namespace, imports_hash_string)
if cache_key not in import_sample_cache:
import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import)
sample_ser.imports = import_sample_cache[cache_key]

content = sample_ser.serialize()
output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name
self.write_file(output_path, content)
except Exception as e: # pylint: disable=broad-except
_LOGGER.error("error happens in sample %s: %s", file, e)

def _serialize_and_write_sample(self, env: Environment):
out_path = self._generated_tests_samples_folder("generated_samples")
sample_additional_folder = self.sample_additional_folder

# Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute
import_sample_cache: dict[tuple[str, str], str] = {}

for client in self.code_model.clients:
for op_group in client.operation_groups:
for operation in op_group.operations:
samples = operation.yaml_data.get("samples")
if not samples or operation.name.startswith("_"):
continue
for value in samples.values():
file = value.get("x-ms-original-file", "sample.json")
file_name = to_snake_case(extract_sample_name(file)) + ".py"
try:
self.write_file(
out_path / self.sample_additional_folder / _sample_output_path(file) / file_name,
SampleSerializer(
code_model=self.code_model,
env=env,
operation_group=op_group,
operation=operation,
sample=value,
file_name=file_name,
).serialize(),
)
except Exception as e: # pylint: disable=broad-except
# sample generation shall not block code generation, so just log error
log_error = f"error happens in sample {file}: {e}"
_LOGGER.error(log_error)
self._process_operation_samples(
samples,
env,
op_group,
operation,
import_sample_cache,
out_path,
sample_additional_folder,
)

def _serialize_and_write_test(self, env: Environment):
self.code_model.for_test = True
out_path = self._generated_tests_samples_folder("generated_tests")

general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())

if not self.code_model.options["azure-arm"]:
for async_mode in (True, False):
async_suffix = "_async" if async_mode else ""
Expand All @@ -578,18 +607,24 @@ def _serialize_and_write_test(self, env: Environment):
general_serializer.serialize_testpreparer(),
)

# Generate test files - reuse serializer per operation group, toggle async_mode
# Cache import_test per (client.name, async_mode) since it's expensive to compute
import_test_cache: dict[tuple[str, bool], str] = {}
for client in self.code_model.clients:
for og in client.operation_groups:
# Create serializer once per operation group
test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og)
for async_mode in (True, False):
try:
try:
for async_mode in (True, False):
test_serializer.async_mode = async_mode
self.write_file(
out_path / f"{to_snake_case(test_serializer.test_class_name)}.py",
test_serializer.serialize_test(),
)
except Exception as e: # pylint: disable=broad-except
# test generation shall not block code generation, so just log error
log_error = f"error happens in test generation for operation group {og.class_name}: {e}"
_LOGGER.error(log_error)
cache_key = (client.name, async_mode)
if cache_key not in import_test_cache:
import_test_cache[cache_key] = test_serializer.get_import_test()
test_serializer.import_test = import_test_cache[cache_key]
content = test_serializer.serialize_test()
output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py"
self.write_file(output_path, content)
except Exception as e: # pylint: disable=broad-except
_LOGGER.error("error happens in test generation for operation group %s: %s", og.class_name, e)

self.code_model.for_test = False
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
BodyParameter,
FileImport,
)
from .utils import create_fake_value

_LOGGER = logging.getLogger(__name__)

Expand All @@ -40,8 +41,17 @@ def __init__(
self.sample = sample
self.file_name = file_name
self.sample_params = sample.get("parameters", {})
self._imports: str = ""

def _imports(self) -> FileImportSerializer:
@property
def imports(self) -> str:
return self._imports

@imports.setter
def imports(self, value: str) -> None:
self._imports = value

def get_file_import(self) -> FileImport:
imports = FileImport(self.code_model)
client = self.operation_group.client
namespace = client.client_namespace
Expand All @@ -59,7 +69,12 @@ def _imports(self) -> FileImportSerializer:
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
if param.client_default_value is None and not param.optional and param.wire_name in self.sample_params:
imports.merge(param.type.imports_for_sample())
return FileImportSerializer(imports, True)

return imports

@staticmethod
def get_imports_from_file_import(file_import: FileImport) -> str:
return str(FileImportSerializer(file_import, True))

def _client_params(self) -> dict[str, Any]:
# client params
Expand Down Expand Up @@ -97,19 +112,14 @@ def handle_param(param: Union[Parameter, BodyParameter], param_value: Any) -> st

# prepare operation parameters
def _operation_params(self) -> dict[str, Any]:
params = [
p
for p in (self.operation.parameters.positional + self.operation.parameters.keyword_only)
if not p.client_default_value
]
failure_info = "fail to find required param named {}"
operation_params = {}
for param in params:
if not param.optional:
for param in self.operation.parameters.positional + self.operation.parameters.keyword_only:
if not param.optional and not param.client_default_value:
param_value = self.sample_params.get(param.wire_name)
if not param_value:
raise Exception(failure_info.format(param.client_name)) # pylint: disable=broad-exception-raised
operation_params[param.client_name] = self.handle_param(param, param_value)
operation_params[param.client_name] = create_fake_value(param.type)
else:
operation_params[param.client_name] = self.handle_param(param, param_value)
return operation_params

def _operation_group_name(self) -> str:
Expand Down Expand Up @@ -154,7 +164,7 @@ def serialize(self) -> str:
operation_params=self._operation_params(),
operation_group_name=self._operation_group_name(),
operation_name=self._operation_name(),
imports=self._imports(),
imports=self.imports,
client_params=self._client_params(),
origin_file=self._origin_file(),
return_var=return_var,
Expand Down
Loading
Loading