diff --git a/.chronus/changes/python-addTypedDict-2026-3-21-17-47-3.md b/.chronus/changes/python-addTypedDict-2026-3-21-17-47-3.md new file mode 100644 index 00000000000..fae8e302e0b --- /dev/null +++ b/.chronus/changes/python-addTypedDict-2026-3-21-17-47-3.md @@ -0,0 +1,8 @@ +--- +# Change versionKind to one of: internal, fix, dependencies, feature, deprecation, breaking +changeKind: feature +packages: + - "@typespec/http-client-python" +--- + +[python] Always generate `TypedDict` typing hints for input models in the `types.py` file, and named union aliases in the `_unions.py` file diff --git a/cspell.yaml b/cspell.yaml index d2e086c1875..04bf892ca38 100644 --- a/cspell.yaml +++ b/cspell.yaml @@ -43,6 +43,7 @@ words: - canonicalizer - Cblack - Cbrown + - clid - clsx - cobertura - codehaus diff --git a/packages/http-client-python/emitter/src/types.ts b/packages/http-client-python/emitter/src/types.ts index f3d53dcfdc8..3f1a629f257 100644 --- a/packages/http-client-python/emitter/src/types.ts +++ b/packages/http-client-python/emitter/src/types.ts @@ -233,6 +233,7 @@ function emitProperty( // Python convert all the type of file part to FileType so clear these models' usage so that they won't be generated addDisableGenerationMap(context, property.type); } + const isNullable = !isMultipartFileInput && sourceType.kind === "nullable"; return { clientName: getClientName(property), isExactName: property.isExactName, @@ -242,6 +243,7 @@ function emitProperty( : property.serializationOptions?.json?.name) ?? property.name, type: getType(context, sourceType), optional: property.optional, + nullable: isNullable, description: property.summary ? property.summary : property.doc, addedOn: getAddedOn(context, property), apiVersions: property.apiVersions, diff --git a/packages/http-client-python/eng/scripts/ci/regenerate.ts b/packages/http-client-python/eng/scripts/ci/regenerate.ts index e6c8cee1110..748e957e328 100644 --- a/packages/http-client-python/eng/scripts/ci/regenerate.ts +++ b/packages/http-client-python/eng/scripts/ci/regenerate.ts @@ -6,29 +6,717 @@ * 1. TypeSpec compile (in-process, parallel) -> emits per-spec YAML only. * 2. Single batched Python subprocess reads all YAMLs and writes the * final `.py` files. Amortizes Python-startup cost across many specs. - * - * Shared helpers/data live in `regenerate-common.ts` (kept identical with the - * `@azure-tools/typespec-python` wrapper copy). */ +import { compile, NodeHost } from "@typespec/compiler"; import { execSync } from "child_process"; -import { existsSync } from "fs"; -import { access, readdir } from "fs/promises"; -import { platform } from "os"; -import { dirname, join, resolve } from "path"; +import { existsSync, rmSync } from "fs"; +import { access, cp, mkdir, mkdtemp, readdir, writeFile } from "fs/promises"; +import { platform, tmpdir } from "os"; +import { dirname, join, relative, resolve } from "path"; import pc from "picocolors"; import { fileURLToPath } from "url"; import { parseArgs } from "util"; -import { - buildTaskGroups, - getSubdirectories, - prepareBaselineOfGeneratedCode, - preprocess, - RegenerateContext, - RegenerateFlags, - runParallel, -} from "./regenerate-common.js"; +// ---- Types ---- + +interface RegenerateFlags { + flavor: string; + debug: boolean; + name?: string; +} + +interface CompileTask { + spec: string; + outputDir: string; + options: Record; +} + +interface TaskGroup { + spec: string; + tasks: CompileTask[]; +} + +interface RegenerateContext { + pluginDir: string; + azureHttpSpecs: string; + httpSpecs: string; + generatedFolder: string; + emitterName: string; +} + +interface BuildTaskGroupsOptions { + emitYamlOnly?: boolean; +} + +// ---- Constants ---- + +const SKIP_SPECS: string[] = [ + "type/file", + "service/multiple-services", + "azure/client-generator-core/response-as-bool", +]; + +const SpecialFlags: Record> = { + azure: { + "generate-test": true, + "generate-sample": true, + }, +}; + +// ---- Spec-specific emitter option overrides ---- + +const AZURE_EMITTER_OPTIONS: Record | Record[]> = { + "azure/client-generator-core/access": { + namespace: "specs.azure.clientgenerator.core.access", + }, + "azure/client-generator-core/alternate-type": { + namespace: "specs.azure.clientgenerator.core.alternatetype", + }, + "azure/client-generator-core/api-version": { + namespace: "specs.azure.clientgenerator.core.apiversion", + }, + "azure/client-generator-core/client-initialization/default": { + namespace: "specs.azure.clientgenerator.core.clientinitialization.default", + }, + "azure/client-generator-core/client-initialization/individually": { + namespace: "specs.azure.clientgenerator.core.clientinitialization.individually", + }, + "azure/client-generator-core/client-initialization/individuallyParent": { + namespace: "specs.azure.clientgenerator.core.clientinitialization.individuallyparent", + }, + "azure/client-generator-core/client-location": { + namespace: "specs.azure.clientgenerator.core.clientlocation", + }, + "azure/client-generator-core/deserialize-empty-string-as-null": { + namespace: "specs.azure.clientgenerator.core.emptystring", + }, + "azure/client-generator-core/flatten-property": { + namespace: "specs.azure.clientgenerator.core.flattenproperty", + }, + "azure/client-generator-core/usage": { + namespace: "specs.azure.clientgenerator.core.usage", + }, + "azure/client-generator-core/client-doc": { + namespace: "specs.azure.clientgenerator.core.clientdoc", + }, + "azure/client-generator-core/override": { + namespace: "specs.azure.clientgenerator.core.override", + }, + "azure/client-generator-core/hierarchy-building": { + namespace: "specs.azure.clientgenerator.core.hierarchybuilding", + }, + "azure/core/basic": { + namespace: "specs.azure.core.basic", + }, + "azure/core/lro/rpc": { + namespace: "specs.azure.core.lro.rpc", + }, + "azure/core/lro/standard": { + namespace: "specs.azure.core.lro.standard", + }, + "azure/core/model": { + namespace: "specs.azure.core.model", + }, + "azure/core/page": { + namespace: "specs.azure.core.page", + }, + "azure/core/scalar": { + namespace: "specs.azure.core.scalar", + }, + "azure/core/traits": { + namespace: "specs.azure.core.traits", + }, + "azure/encode/duration": { + namespace: "specs.azure.encode.duration", + }, + "azure/example/basic": { + namespace: "specs.azure.example.basic", + }, + "azure/payload/pageable": { + namespace: "specs.azure.payload.pageable", + }, + "azure/versioning/previewVersion": { + namespace: "specs.azure.versioning.previewversion", + }, + "client/structure/default": { + namespace: "client.structure.service", + }, + "client/structure/multi-client": { + "package-name": "client-structure-multiclient", + namespace: "client.structure.multiclient", + }, + "client/structure/renamed-operation": { + "package-name": "client-structure-renamedoperation", + namespace: "client.structure.renamedoperation", + }, + "client/structure/two-operation-group": { + "package-name": "client-structure-twooperationgroup", + namespace: "client.structure.twooperationgroup", + }, + "client/naming": [ + { + namespace: "client.naming.main", + }, + { + "package-name": "client-naming-typeddict", + namespace: "client.naming.typeddict", + }, + ], + "client/overload": { + namespace: "client.overload", + }, + "encode/duration": { + namespace: "encode.duration", + }, + "encode/numeric": { + namespace: "encode.numeric", + }, + "parameters/basic": { + namespace: "parameters.basic", + }, + "parameters/spread": { + namespace: "parameters.spread", + }, + "payload/content-negotiation": { + namespace: "payload.contentnegotiation", + }, + "payload/multipart": { + namespace: "payload.multipart", + }, + "serialization/encoded-name/json": { + namespace: "serialization.encodedname.json", + }, + "special-words": { + namespace: "specialwords", + }, + "service/multi-service": { + namespace: "service.multiservice", + }, + "client/structure/client-operation-group": { + "package-name": "client-structure-clientoperationgroup", + namespace: "client.structure.clientoperationgroup", + }, +}; + +const EMITTER_OPTIONS: Record | Record[]> = { + "resiliency/srv-driven/old.tsp": { + "package-name": "resiliency-srv-driven1", + namespace: "resiliency.srv.driven1", + "package-mode": "azure-dataplane", + "package-pprint-name": "ResiliencySrvDriven1", + }, + "resiliency/srv-driven": { + "package-name": "resiliency-srv-driven2", + namespace: "resiliency.srv.driven2", + "package-mode": "azure-dataplane", + "package-pprint-name": "ResiliencySrvDriven2", + }, + "authentication/api-key": { + "clear-output-folder": "true", + }, + "authentication/http/custom": { + "package-name": "authentication-http-custom", + namespace: "authentication.http.custom", + "package-pprint-name": "Authentication Http Custom", + }, + "authentication/union": [ + { + "package-name": "authentication-union", + namespace: "authentication.union", + }, + { + "package-name": "setuppy-authentication-union", + namespace: "setuppy.authentication.union", + "keep-setup-py": "true", + }, + ], + "type/array": { + "package-name": "typetest-array", + namespace: "typetest.array", + }, + "type/dictionary": { + "package-name": "typetest-dictionary", + namespace: "typetest.dictionary", + }, + "type/enum/extensible": { + "package-name": "typetest-enum-extensible", + namespace: "typetest.enum.extensible", + }, + "type/enum/fixed": { + "package-name": "typetest-enum-fixed", + namespace: "typetest.enum.fixed", + }, + "type/model/empty": { + "package-name": "typetest-model-empty", + namespace: "typetest.model.empty", + }, + "type/model/inheritance/enum-discriminator": { + "package-name": "typetest-model-enumdiscriminator", + namespace: "typetest.model.enumdiscriminator", + }, + "type/model/inheritance/nested-discriminator": { + "package-name": "typetest-model-nesteddiscriminator", + namespace: "typetest.model.nesteddiscriminator", + }, + "type/model/inheritance/not-discriminated": [ + { + "package-name": "typetest-model-notdiscriminated", + namespace: "typetest.model.notdiscriminated", + }, + { + "package-name": "typetest-model-notdiscriminated-typeddict", + namespace: "typetest.model.notdiscriminated.typeddict", + }, + ], + "type/model/inheritance/single-discriminator": [ + { + "package-name": "typetest-model-singlediscriminator", + namespace: "typetest.model.singlediscriminator", + }, + { + "package-name": "typetest-model-singlediscriminator-typeddict", + namespace: "typetest.model.singlediscriminator.typeddict", + }, + ], + "type/model/inheritance/recursive": [ + { + "package-name": "typetest-model-recursive", + namespace: "typetest.model.recursive", + }, + { + "package-name": "generation-subdir", + namespace: "generation.subdir", + "generation-subdir": "_generated", + "generate-test": "false", + "clear-output-folder": "true", + }, + ], + "type/model/usage": [ + { + "package-name": "typetest-model-usage", + namespace: "typetest.model.usage", + }, + { + "package-name": "typetest-model-usage-typeddictonly", + namespace: "typetest.model.usage.typeddictonly", + "models-mode": "typeddict", + }, + ], + "type/model/visibility": [ + { + "package-name": "typetest-model-visibility", + namespace: "typetest.model.visibility", + }, + { + "package-name": "headasbooleantrue", + namespace: "headasbooleantrue", + "head-as-boolean": "true", + }, + { + "package-name": "headasbooleanfalse", + namespace: "headasbooleanfalse", + "head-as-boolean": "false", + }, + ], + "type/property/nullable": { + "package-name": "typetest-property-nullable", + namespace: "typetest.property.nullable", + }, + "type/property/optionality": { + "package-name": "typetest-property-optional", + namespace: "typetest.property.optional", + }, + "type/property/additional-properties": { + "package-name": "typetest-property-additionalproperties", + namespace: "typetest.property.additionalproperties", + }, + "type/scalar": { + "package-name": "typetest-scalar", + namespace: "typetest.scalar", + }, + "type/property/value-types": { + "package-name": "typetest-property-valuetypes", + namespace: "typetest.property.valuetypes", + }, + "type/union": { + "package-name": "typetest-union", + namespace: "typetest.union", + }, + "type/union/discriminated": { + "package-name": "typetest-discriminatedunion", + namespace: "typetest.discriminatedunion", + }, + "type/file": { + "package-name": "typetest-file", + namespace: "typetest.file", + }, + documentation: { + "package-name": "specs-documentation", + namespace: "specs.documentation", + }, + "versioning/added": [ + { + "package-name": "versioning-added", + namespace: "versioning.added", + }, + { + "package-name": "generation-subdir2", + namespace: "generation.subdir2", + "generate-test": "false", + "generation-subdir": "_generated", + }, + ], +}; + +// ---- Helpers ---- + +function toPosix(p: string): string { + return p.replace(/\\/g, "/"); +} + +function isAzureSpec(spec: string): boolean { + return spec.includes("azure-http-specs"); +} + +function defaultPackageName(spec: string, ctx: RegenerateContext): string { + const specDir = isAzureSpec(spec) ? ctx.azureHttpSpecs : ctx.httpSpecs; + return toPosix(relative(specDir, dirname(spec))) + .replace(/\//g, "-") + .toLowerCase(); +} + +function getEmitterOptions( + spec: string, + flavor: string, + ctx: RegenerateContext, +): Record[] { + const specDir = isAzureSpec(spec) ? ctx.azureHttpSpecs : ctx.httpSpecs; + const relativeSpec = toPosix(relative(specDir, spec)); + const key = relativeSpec.includes("resiliency/srv-driven/old.tsp") + ? relativeSpec + : dirname(relativeSpec); + const emitterOpts = EMITTER_OPTIONS[key] || + (flavor === "azure" ? AZURE_EMITTER_OPTIONS[key] : [{}]) || [{}]; + return Array.isArray(emitterOpts) ? emitterOpts : [emitterOpts]; +} + +async function getSubdirectories(baseDir: string, flags: RegenerateFlags): Promise { + const subdirectories: string[] = []; + + async function searchDir(currentDir: string) { + const items = await readdir(currentDir, { withFileTypes: true }); + + const promisesArray = items.map(async (item) => { + const subDirPath = join(currentDir, item.name); + if (item.isDirectory()) { + const mainTspPath = join(subDirPath, "main.tsp"); + const clientTspPath = join(subDirPath, "client.tsp"); + + const mainTspRelativePath = toPosix(relative(baseDir, mainTspPath)); + + if (SKIP_SPECS.some((skipSpec) => mainTspRelativePath.includes(skipSpec))) return; + + const hasMainTsp = await access(mainTspPath) + .then(() => true) + .catch(() => false); + const hasClientTsp = await access(clientTspPath) + .then(() => true) + .catch(() => false); + + if (mainTspRelativePath.toLowerCase().includes(flags.name || "")) { + if (mainTspRelativePath.includes("resiliency/srv-driven")) { + subdirectories.push(resolve(subDirPath, "old.tsp")); + } + if (hasClientTsp) { + subdirectories.push(resolve(subDirPath, "client.tsp")); + } else if (hasMainTsp) { + subdirectories.push(resolve(subDirPath, "main.tsp")); + } + } + + await searchDir(subDirPath); + } + }); + + await Promise.all(promisesArray); + } + + await searchDir(baseDir); + return subdirectories; +} + +function buildTaskGroups( + specs: string[], + flags: RegenerateFlags, + ctx: RegenerateContext, + options: BuildTaskGroupsOptions = {}, +): TaskGroup[] { + const groups: TaskGroup[] = []; + + for (const spec of specs) { + const tasks: CompileTask[] = []; + + for (const emitterConfig of getEmitterOptions(spec, flags.flavor, ctx)) { + const opts: Record = {}; + for (const [k, v] of Object.entries(SpecialFlags[flags.flavor] ?? {})) { + opts[k] = v; + } + Object.assign(opts, emitterConfig); + + opts["flavor"] = flags.flavor; + + const packageName = (opts["package-name"] as string) || defaultPackageName(spec, ctx); + const outputDir = + (opts["emitter-output-dir"] as string) || + toPosix(`${ctx.generatedFolder}/../tests/generated/${flags.flavor}/${packageName}`); + opts["emitter-output-dir"] = outputDir; + + if (flags.debug) { + opts["debug"] = true; + } + + opts["examples-dir"] = toPosix(join(dirname(spec), "examples")); + + if (options.emitYamlOnly) { + opts["emit-yaml-only"] = true; + } + + tasks.push({ spec, outputDir, options: opts }); + } + + groups.push({ spec, tasks }); + } + + return groups; +} + +async function compileSpec( + task: CompileTask, + ctx: RegenerateContext, +): Promise<{ success: boolean; error?: string }> { + const { spec, outputDir, options } = task; + + try { + const compilerOptions = { + emit: [ctx.pluginDir], + options: { + [ctx.emitterName]: options, + }, + }; + + const program = await compile(NodeHost, spec, compilerOptions); + + if (program.hasError()) { + const errors = program.diagnostics + .filter((d) => d.severity === "error") + .map((d) => d.message) + .join("\n"); + return { success: false, error: errors }; + } + + return { success: true }; + } catch (err) { + rmSync(outputDir, { recursive: true, force: true }); + return { success: false, error: String(err) }; + } +} + +function renderProgressBar( + completed: number, + failed: number, + total: number, + width: number = 40, +): string { + const successCount = completed - failed; + const successWidth = Math.round((successCount / total) * width); + const failWidth = Math.round((failed / total) * width); + const emptyWidth = width - successWidth - failWidth; + + const successBar = pc.bgGreen(" ".repeat(successWidth)); + const failBar = failed > 0 ? pc.bgRed(" ".repeat(failWidth)) : ""; + const emptyBar = pc.dim("░".repeat(Math.max(0, emptyWidth))); + + const percent = Math.round((completed / total) * 100); + return `${successBar}${failBar}${emptyBar} ${pc.cyan(`${percent}%`)} (${completed}/${total})`; +} + +async function runParallel( + groups: TaskGroup[], + maxJobs: number, + ctx: RegenerateContext, +): Promise> { + const results = new Map(); + const executing: Set> = new Set(); + + const totalTasks = groups.reduce((sum, g) => sum + g.tasks.length, 0); + let completed = 0; + let failed = 0; + const failedSpecs: string[] = []; + + const isTTY = process.stdout.isTTY; + + const updateProgress = () => { + if (isTTY) { + process.stdout.write(`\r${renderProgressBar(completed, failed, totalTasks)}`); + } + }; + + updateProgress(); + + for (const group of groups) { + const runGroup = async () => { + const specDir = isAzureSpec(group.spec) ? ctx.azureHttpSpecs : ctx.httpSpecs; + const shortName = toPosix(relative(specDir, dirname(group.spec))); + + let groupSuccess = true; + for (const task of group.tasks) { + const packageName = (task.options["package-name"] as string) || shortName; + + const result = await compileSpec(task, ctx); + completed++; + + if (!result.success) { + failed++; + failedSpecs.push(`${packageName}: ${result.error}`); + groupSuccess = false; + } + + updateProgress(); + } + + results.set(group.spec, groupSuccess); + }; + + const p = runGroup().finally(() => executing.delete(p)); + executing.add(p); + + if (executing.size >= maxJobs) { + await Promise.race(executing); + } + } + + await Promise.all(executing); + + if (isTTY) { + process.stdout.write("\r" + " ".repeat(60) + "\r"); + } + + if (failedSpecs.length > 0) { + console.log(pc.red(`\nFailed specs:`)); + for (const spec of failedSpecs) { + console.log(pc.red(` • ${spec}`)); + } + } + + return results; +} + +async function preprocess(flavor: string, generatedFolder: string): Promise { + if (flavor !== "azure") return; + + const testsGeneratedDir = resolve(generatedFolder, "../tests/generated/azure"); + + const DELETE_CONTENT = "# This file is to be deleted after regeneration"; + const KEEP_CONTENT = "# This file is to be kept after regeneration"; + const DELETE_FILE = "to_be_deleted.py"; + const entries: { folder: string[]; file: string; content: string }[] = [ + { + folder: ["authentication-api-key", "authentication", "apikey", "_operations"], + file: DELETE_FILE, + content: DELETE_CONTENT, + }, + { + folder: ["generation-subdir", "generation", "subdir", "_generated"], + file: DELETE_FILE, + content: DELETE_CONTENT, + }, + { + folder: ["generation-subdir", "generated_tests"], + file: DELETE_FILE, + content: DELETE_CONTENT, + }, + { + folder: ["generation-subdir", "generation", "subdir"], + file: "to_be_kept.py", + content: KEEP_CONTENT, + }, + ]; + + await Promise.all( + entries.map(async ({ folder, file, content }) => { + const targetFolder = join(testsGeneratedDir, ...folder); + await mkdir(targetFolder, { recursive: true }); + await writeFile(join(targetFolder, file), content); + }), + ); +} + +async function prepareBaselineOfGeneratedCode(generatedFolder: string): Promise { + const repoUrl = "https://github.com/Azure/azure-sdk-for-python.git"; + const branch = "typespec-python-generated-tests"; + const sourceSubdir = "eng/tools/azure-sdk-tools/emitter/generated"; + const testsGeneratedDir = resolve(generatedFolder, "../tests/generated"); + + console.log(pc.cyan(`\n${"=".repeat(60)}`)); + console.log(pc.cyan(`Resetting baseline from ${repoUrl} (${branch}/${sourceSubdir})`)); + console.log(pc.cyan(`${"=".repeat(60)}\n`)); + + if (existsSync(testsGeneratedDir)) { + console.log(pc.dim(`Removing ${testsGeneratedDir}`)); + rmSync(testsGeneratedDir, { recursive: true, force: true }); + } + + const tempDir = await mkdtemp(join(tmpdir(), "azsdk-baseline-")); + try { + console.log(pc.dim(`Cloning into ${tempDir}`)); + const run = (cmd: string) => + execSync(cmd, { cwd: tempDir, stdio: ["ignore", "ignore", "inherit"] }); + + run(`git init`); + run(`git config core.longpaths true`); + run(`git remote add origin ${repoUrl}`); + run(`git config core.sparseCheckout true`); + run(`git sparse-checkout init --cone`); + run(`git sparse-checkout set ${sourceSubdir}`); + run(`git fetch --depth 1 origin ${branch}`); + run(`git checkout FETCH_HEAD`); + + const sourceRoot = join(tempDir, ...sourceSubdir.split("/")); + for (const flavor of ["azure", "unbranded"]) { + const src = join(sourceRoot, flavor); + const dest = join(testsGeneratedDir, flavor); + if (!existsSync(src)) { + console.warn(pc.yellow(`Baseline folder not found: ${src}`)); + continue; + } + console.log(pc.dim(`Copying ${flavor}/ -> ${dest}`)); + await cp(src, dest, { recursive: true }); + } + + console.log(pc.green(`Baseline reset complete.\n`)); + } finally { + rmSync(tempDir, { recursive: true, force: true }); + } + + const deleteIfExists = (path: string) => { + if (!existsSync(path)) return; + console.log(pc.dim(`Deleting ${path}`)); + rmSync(path, { recursive: true, force: true }); + }; + + deleteIfExists(join(testsGeneratedDir, "azure", "authentication-http-custom")); + deleteIfExists(join(testsGeneratedDir, "unbranded", "encode-array")); + + if (existsSync(testsGeneratedDir)) { + const entries = await readdir(testsGeneratedDir, { recursive: true, withFileTypes: true }); + for (const entry of entries) { + if (entry.isFile() && entry.name === "README.md") { + deleteIfExists(join(entry.parentPath, entry.name)); + } + } + } +} // Parse arguments const argv = parseArgs({ diff --git a/packages/http-client-python/generator/pygen/__init__.py b/packages/http-client-python/generator/pygen/__init__.py index 99c890b9553..b8841e4fc1a 100644 --- a/packages/http-client-python/generator/pygen/__init__.py +++ b/packages/http-client-python/generator/pygen/__init__.py @@ -168,10 +168,10 @@ def _validate_and_transform(self, key: str, value: Any) -> Any: if key == "models-mode" and value == "none": value = False # switch to falsy value for easier code writing - if key == "models-mode" and value not in ["msrest", "dpg", False]: + if key == "models-mode" and value not in ["msrest", "dpg", "typeddict", False]: raise ValueError( - "--models-mode can only be 'msrest', 'dpg' or 'none'. " - "Pass in 'msrest' if you want msrest models, or " + "--models-mode can only be 'msrest', 'dpg', 'typeddict', or 'none'. " + "Pass in 'msrest' if you want msrest models, 'typeddict' for TypedDict models, or " "'none' if you don't want any." ) if key == "package-mode": diff --git a/packages/http-client-python/generator/pygen/codegen/models/__init__.py b/packages/http-client-python/generator/pygen/codegen/models/__init__.py index a1d9f9a4dbc..1706576fb9d 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/__init__.py +++ b/packages/http-client-python/generator/pygen/codegen/models/__init__.py @@ -9,7 +9,7 @@ from .base_builder import BaseBuilder, ParameterListType from .code_model import CodeModel from .client import Client -from .model_type import ModelType, JSONModelType, DPGModelType, MsrestModelType +from .model_type import ModelType, JSONModelType, DPGModelType, MsrestModelType, TypedDictModelType from .dictionary_type import DictionaryType from .list_type import ListType from .combined_type import CombinedType @@ -167,7 +167,9 @@ def build_type(yaml_data: dict[str, Any], code_model: CodeModel) -> BaseType: response: Optional[BaseType] = None if yaml_data["type"] == "model": # need to special case model to avoid recursion - if yaml_data["base"] == "json" or not code_model.options["models-mode"]: + if yaml_data["base"] == "typeddict": + model_type = TypedDictModelType # type: ignore + elif yaml_data["base"] == "json" or not code_model.options["models-mode"]: model_type = JSONModelType elif yaml_data["base"] == "dpg": model_type = DPGModelType # type: ignore diff --git a/packages/http-client-python/generator/pygen/codegen/models/code_model.py b/packages/http-client-python/generator/pygen/codegen/models/code_model.py index b6c3b2dda0d..112485a3f86 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/code_model.py +++ b/packages/http-client-python/generator/pygen/codegen/models/code_model.py @@ -152,16 +152,28 @@ def get_relative_import_path( return result return f"{result}{module_name}" if result.endswith(".") else f"{result}.{module_name}" - def get_unique_models_alias(self, serialize_namespace: str, imported_namespace: str) -> str: + def _get_unique_import_alias(self, serialize_namespace: str, imported_namespace: str, module_name: str) -> str: if not self.has_subnamespace: - return "_models" + return f"_{module_name}" relative_path = self.get_relative_import_path( serialize_namespace, self.get_imported_namespace_for_model(imported_namespace) ) dot_num = max(relative_path.count(".") - 1, 0) - parts = [""] + ([p for p in relative_path.split(".") if p] or ["models"]) + path_parts = [p for p in relative_path.split(".") if p] + # For "models", keep existing format: _ (e.g. _models1, _firstnamespace_models2) + # For other modules like "types", prefix with module name: _types_ + if module_name == "models": + parts = [""] + (path_parts or [module_name]) + else: + parts = [f"_{module_name}"] + (path_parts or []) return "_".join(parts) + (str(dot_num) if dot_num > 0 else "") + def get_unique_models_alias(self, serialize_namespace: str, imported_namespace: str) -> str: + return self._get_unique_import_alias(serialize_namespace, imported_namespace, "models") + + def get_unique_types_alias(self, serialize_namespace: str, imported_namespace: str) -> str: + return self._get_unique_import_alias(serialize_namespace, imported_namespace, "types") + @property def client_namespace_types(self) -> dict[str, ClientNamespaceType]: if not self._client_namespace_types: @@ -174,6 +186,14 @@ def client_namespace_types(self) -> dict[str, ClientNamespaceType]: if model.client_namespace not in self._client_namespace_types: self._client_namespace_types[model.client_namespace] = ClientNamespaceType() self._client_namespace_types[model.client_namespace].models.append(model) + # TypedDict copies (base="typeddict") are excluded from model_types to keep + # them out of _models.py, but they need to be in the namespace model list + # so the TypesSerializer can render them in types.py. + for t in self.types_map.values(): + if isinstance(t, ModelType) and t.base == "typeddict" and t.usage != UsageFlags.Default.value: + if t.client_namespace not in self._client_namespace_types: + self._client_namespace_types[t.client_namespace] = ClientNamespaceType() + self._client_namespace_types[t.client_namespace].models.append(t) for enum in self.enums: if enum.client_namespace not in self._client_namespace_types: self._client_namespace_types[enum.client_namespace] = ClientNamespaceType() @@ -339,7 +359,9 @@ def model_types(self) -> list[ModelType]: """All of the model types in this class""" if not self._model_types: self._model_types = [ - t for t in self.types_map.values() if isinstance(t, ModelType) and t.usage != UsageFlags.Default.value + t + for t in self.types_map.values() + if isinstance(t, ModelType) and t.usage != UsageFlags.Default.value and t.base != "typeddict" ] return self._model_types @@ -349,7 +371,7 @@ def model_types(self, val: list[ModelType]) -> None: @staticmethod def get_public_model_types(models: list[ModelType]) -> list[ModelType]: - return [m for m in models if not m.internal and not m.base == "json"] + return [m for m in models if not m.internal and not m.base == "json" and not m.is_typed_dict_only] @property def public_model_types(self) -> list[ModelType]: diff --git a/packages/http-client-python/generator/pygen/codegen/models/combined_type.py b/packages/http-client-python/generator/pygen/codegen/models/combined_type.py index 249b7474dcd..1014482f5e0 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/combined_type.py +++ b/packages/http-client-python/generator/pygen/codegen/models/combined_type.py @@ -66,7 +66,7 @@ def docstring_type(self, **kwargs: Any) -> str: def type_annotation(self, **kwargs: Any) -> str: if self.name: - return f'"_types.{self.name}"' + return f'"_unions.{self.name}"' return self.type_definition(**kwargs) def type_definition(self, **kwargs: Any) -> str: @@ -116,10 +116,10 @@ def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) serialize_namespace_type = kwargs.get("serialize_namespace_type") - if self.name and serialize_namespace_type != NamespaceType.TYPES_FILE: + if self.name and serialize_namespace_type != NamespaceType.UNIONS_FILE: file_import.add_submodule_import( self.code_model.get_relative_import_path(serialize_namespace), - "_types", + "_unions", ImportType.LOCAL, TypingSection.TYPING, ) diff --git a/packages/http-client-python/generator/pygen/codegen/models/enum_type.py b/packages/http-client-python/generator/pygen/codegen/models/enum_type.py index 9cbec3d1b30..e9fd13a5146 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/enum_type.py +++ b/packages/http-client-python/generator/pygen/codegen/models/enum_type.py @@ -167,6 +167,10 @@ def description(self, *, is_operation_file: bool) -> str: enum_description = f"Known values are: {possible_values_str}." return enum_description + @property + def is_typeddict_mode(self) -> bool: + return self.code_model.options["models-mode"] == "typeddict" + def type_annotation(self, **kwargs: Any) -> str: """The python type used for type annotation @@ -174,9 +178,20 @@ def type_annotation(self, **kwargs: Any) -> str: :rtype: str """ if self.code_model.options["models-mode"]: + if self.is_typeddict_mode: + # In typeddict mode, enums are Literal aliases defined in types.py + serialize_namespace_type = kwargs.get("serialize_namespace_type") + if serialize_namespace_type == NamespaceType.TYPES_FILE: + # Same file — just the name, no module prefix + return self.name + # From operation/client files, use types.EnumName (matching model pattern) + return f"types.{self.name}" module_name = "" - if kwargs.get("need_model_alias", True): + serialize_namespace_type = kwargs.get("serialize_namespace_type") + if serialize_namespace_type == NamespaceType.TYPES_FILE: + pass # no module prefix for types.py + elif kwargs.get("need_model_alias", True): serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) model_alias = self.code_model.get_unique_models_alias(serialize_namespace, self.client_namespace) module_name = f"{model_alias}." @@ -240,31 +255,66 @@ def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) file_import.merge(self.value_type.imports(**kwargs)) if self.code_model.options["models-mode"]: - file_import.add_submodule_import("typing", "Union", ImportType.STDLIB, TypingSection.REGULAR) + if self.is_typeddict_mode: + # In typeddict mode, enums are Literal aliases in types.py — no Union needed + serialize_namespace_type = kwargs.get("serialize_namespace_type") + if serialize_namespace_type == NamespaceType.TYPES_FILE: + # Same file — no import needed for same-namespace enums + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + if self.client_namespace != serialize_namespace: + # Cross-namespace: import from sibling types module + relative_path = self.code_model.get_relative_import_path( + serialize_namespace, self.client_namespace + ) + file_import.add_submodule_import( + f"{relative_path}types" if relative_path != "." else ".types", + self.name, + ImportType.LOCAL, + typing_section=TypingSection.REGULAR, + ) + elif serialize_namespace_type in [NamespaceType.OPERATION, NamespaceType.CLIENT]: + # Import types module directly (matching model pattern) + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + relative_path = self.code_model.get_relative_import_path(serialize_namespace, self.client_namespace) + file_import.add_submodule_import( + relative_path, + "types", + ImportType.LOCAL, + ) + else: + file_import.add_submodule_import("typing", "Union", ImportType.STDLIB, TypingSection.REGULAR) - serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) - relative_path = self.code_model.get_relative_import_path(serialize_namespace, self.client_namespace) - alias = self.code_model.get_unique_models_alias(serialize_namespace, self.client_namespace) - serialize_namespace_type = kwargs.get("serialize_namespace_type") - called_by_property = kwargs.get("called_by_property", False) - if serialize_namespace_type in [NamespaceType.OPERATION, NamespaceType.CLIENT]: - file_import.add_submodule_import( - relative_path, - "models", - ImportType.LOCAL, - alias=alias, - typing_section=TypingSection.REGULAR, - ) - elif serialize_namespace_type == NamespaceType.TYPES_FILE or ( - serialize_namespace_type == NamespaceType.MODEL and called_by_property - ): - file_import.add_submodule_import( - relative_path, - "models", - ImportType.LOCAL, - alias=alias, - typing_section=TypingSection.TYPING, - ) + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + relative_path = self.code_model.get_relative_import_path(serialize_namespace, self.client_namespace) + alias = self.code_model.get_unique_models_alias(serialize_namespace, self.client_namespace) + serialize_namespace_type = kwargs.get("serialize_namespace_type") + called_by_property = kwargs.get("called_by_property", False) + if serialize_namespace_type in [NamespaceType.OPERATION, NamespaceType.CLIENT]: + file_import.add_submodule_import( + relative_path, + "models", + ImportType.LOCAL, + alias=alias, + typing_section=TypingSection.REGULAR, + ) + elif serialize_namespace_type == NamespaceType.TYPES_FILE: + # Import enum name directly to avoid dotted forward refs in TypedDict annotations + file_import.add_submodule_import( + f"{relative_path}models" if relative_path != "." else ".models", + self.name, + ImportType.LOCAL, + typing_section=TypingSection.TYPING, + ) + elif serialize_namespace_type == NamespaceType.UNIONS_FILE or ( + serialize_namespace_type == NamespaceType.MODEL and called_by_property + ): + file_import.add_submodule_import( + relative_path, + "models", + ImportType.LOCAL, + alias=alias, + typing_section=TypingSection.TYPING, + ) file_import.merge(self.value_type.imports(**kwargs)) return file_import diff --git a/packages/http-client-python/generator/pygen/codegen/models/list_type.py b/packages/http-client-python/generator/pygen/codegen/models/list_type.py index 38d5b31ef07..275432b4179 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/list_type.py +++ b/packages/http-client-python/generator/pygen/codegen/models/list_type.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union, TYPE_CHECKING from .base import BaseType from .imports import FileImport +from .utils import NamespaceType if TYPE_CHECKING: from .code_model import CodeModel @@ -41,12 +42,17 @@ def type_annotation(self, **kwargs: Any) -> str: # this means we're version tolerant XML, we just return the XML element return self.element_type.type_annotation(**kwargs) - # if there is a function/property named `list` we have to make sure there's no conflict with the built-in `list` + # if there is a function named `list` we have to make sure there's no conflict with the built-in `list` + # in operation files. The operation_groups_serializer defines `List = list` alias for this case. + serialize_namespace_type = kwargs.get("serialize_namespace_type") is_operation_file = kwargs.get("is_operation_file", False) - use_list_import = (self.code_model.has_operation_named_list and is_operation_file) or ( - self.code_model.has_property_named_list and not is_operation_file + in_operation_context = ( + serialize_namespace_type in (NamespaceType.OPERATION, NamespaceType.CLIENT) or is_operation_file ) - list_type = "List" if use_list_import else "list" + if in_operation_context and self.code_model.has_operation_named_list: + list_type = "List" + else: + list_type = "list" return f"{list_type}[{self.element_type.type_annotation(**kwargs)}]" def description(self, *, is_operation_file: bool) -> str: diff --git a/packages/http-client-python/generator/pygen/codegen/models/model_type.py b/packages/http-client-python/generator/pygen/codegen/models/model_type.py index 76ef6e8be8d..c4f83fc6c18 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/model_type.py +++ b/packages/http-client-python/generator/pygen/codegen/models/model_type.py @@ -77,11 +77,19 @@ def __init__( self.cross_language_definition_id: Optional[str] = self.yaml_data.get("crossLanguageDefinitionId") self.usage: int = self.yaml_data.get("usage", UsageFlags.Input.value | UsageFlags.Output.value) self.client_namespace: str = self.yaml_data.get("clientNamespace", code_model.namespace) + self.is_typed_dict_only: bool = ( + self.yaml_data.get("typedDictOnly", False) or code_model.options["models-mode"] == "typeddict" + ) @property def is_usage_output(self) -> bool: return bool(self.usage & UsageFlags.Output.value) + @property + def is_used_in_operations_via_types(self) -> bool: + """Whether this model would be imported from types.py (not models) in operations.""" + return False + @property def flattened_property(self) -> Optional[Property]: try: @@ -275,12 +283,20 @@ class GeneratedModelType(ModelType): def type_annotation(self, **kwargs: Any) -> str: is_operation_file = kwargs.pop("is_operation_file", False) skip_quote = kwargs.get("skip_quote", False) + serialize_namespace_type = kwargs.get("serialize_namespace_type") module_name = "" - if kwargs.get("need_model_alias", True): + # In types.py, use bare name to avoid pyright "variable in type expression" errors + if serialize_namespace_type == NamespaceType.TYPES_FILE: + pass # no module prefix, no internal file prefix + elif kwargs.get("need_model_alias", True): serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) model_alias = self.code_model.get_unique_models_alias(serialize_namespace, self.client_namespace) module_name = f"{model_alias}." - file_name = f"{self.code_model.models_filename}." if self.internal else "" + file_name = ( + f"{self.code_model.models_filename}." + if self.internal and serialize_namespace_type != NamespaceType.TYPES_FILE + else "" + ) retval = module_name + file_name + self.name return retval if is_operation_file or skip_quote else f'"{retval}"' @@ -305,7 +321,7 @@ def imports(self, **kwargs: Any) -> FileImport: alias = self.code_model.get_unique_models_alias(serialize_namespace, self.client_namespace) serialize_namespace_type = kwargs.get("serialize_namespace_type") called_by_property = kwargs.get("called_by_property", False) - # add import for models in operations or _types file + # add import for models in operations, types, or unions file if serialize_namespace_type in [NamespaceType.OPERATION, NamespaceType.CLIENT]: file_import.add_submodule_import( relative_path, @@ -320,7 +336,33 @@ def imports(self, **kwargs: Any) -> FileImport: ImportType.LOCAL, alias="_Model", ) - elif serialize_namespace_type == NamespaceType.TYPES_FILE or ( + elif serialize_namespace_type == NamespaceType.TYPES_FILE: + # Don't import models that will be defined in this namespace's types.py — + # either as TypedDict classes (non-discriminated) or as Union aliases (discriminated bases). + # Only same-namespace non-json models are in the same types.py file. + same_namespace = relative_path == "." + will_be_in_types_file = self.base != "json" and same_namespace + if not will_be_in_types_file: + if same_namespace: + # json models from same namespace — import from .models (or .models._models for internal) + import_path = f".models.{self.code_model.models_filename}" if self.internal else ".models" + file_import.add_submodule_import( + import_path, + self.name, + ImportType.LOCAL, + typing_section=TypingSection.TYPING, + ) + else: + # Cross-namespace model — import from sibling namespace's types module + file_import.add_submodule_import( + self.code_model.get_relative_import_path( + serialize_namespace, self.client_namespace, module_name="types" + ), + self.name, + ImportType.LOCAL, + typing_section=TypingSection.TYPING, + ) + elif serialize_namespace_type == NamespaceType.UNIONS_FILE or ( serialize_namespace_type == NamespaceType.MODEL and called_by_property ): file_import.add_submodule_import( @@ -352,6 +394,33 @@ def imports(self, **kwargs: Any) -> FileImport: class DPGModelType(GeneratedModelType): base = "dpg" + @property + def is_used_in_operations_via_types(self) -> bool: + return self.is_typed_dict_only + + def type_annotation(self, **kwargs: Any) -> str: + if self.is_typed_dict_only: + is_operation_file = kwargs.pop("is_operation_file", False) + skip_quote = kwargs.get("skip_quote", False) + serialize_namespace_type = kwargs.get("serialize_namespace_type") + # Within types.py, use bare name (no module prefix) + if serialize_namespace_type == NamespaceType.TYPES_FILE: + retval = self.name + else: + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + types_alias = self.code_model.get_unique_types_alias(serialize_namespace, self.client_namespace) + retval = f"{types_alias}.{self.name}" + return retval if is_operation_file or skip_quote else f'"{retval}"' + return super().type_annotation(**kwargs) + + def docstring_type(self, **kwargs: Any) -> str: + if self.is_typed_dict_only: + client_namespace = self.client_namespace + if self.code_model.options.get("generation-subdir"): + client_namespace += f".{self.code_model.options['generation-subdir']}" + return f"~{client_namespace}.types.{self.name}" + return super().docstring_type(**kwargs) + def serialization_type(self, **kwargs: Any) -> str: return ( self.type_annotation(skip_quote=True, **kwargs) @@ -364,7 +433,79 @@ def instance_check_template(self) -> str: return "isinstance({}, " + f"_models.{self.name})" def imports(self, **kwargs: Any) -> FileImport: + if self.is_typed_dict_only: + file_import = FileImport(self.code_model) + serialize_namespace_type = kwargs.get("serialize_namespace_type") + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + relative_path = self.code_model.get_relative_import_path(serialize_namespace, self.client_namespace) + alias = self.code_model.get_unique_types_alias(serialize_namespace, self.client_namespace) + same_namespace = relative_path == "." + if serialize_namespace_type in [NamespaceType.OPERATION, NamespaceType.CLIENT]: + file_import.add_submodule_import( + relative_path, + "types", + ImportType.LOCAL, + alias=alias, + ) + elif serialize_namespace_type == NamespaceType.TYPES_FILE and same_namespace: + pass # model is defined in this types.py — no import needed + elif serialize_namespace_type in [NamespaceType.TYPES_FILE, NamespaceType.UNIONS_FILE] or ( + serialize_namespace_type == NamespaceType.MODEL and kwargs.get("called_by_property", False) + ): + file_import.add_submodule_import( + relative_path, + "types", + ImportType.LOCAL, + alias=alias, + typing_section=TypingSection.TYPING, + ) + return file_import file_import = super().imports(**kwargs) if self.flattened_property: file_import.add_submodule_import("typing", "Any", ImportType.STDLIB) return file_import + + +class TypedDictModelType(DPGModelType): + base = "typeddict" + + @property + def is_used_in_operations_via_types(self) -> bool: + return True + + def type_annotation(self, **kwargs: Any) -> str: + is_operation_file = kwargs.pop("is_operation_file", False) + skip_quote = kwargs.get("skip_quote", False) + serialize_namespace_type = kwargs.get("serialize_namespace_type") + if serialize_namespace_type == NamespaceType.TYPES_FILE: + retval = self.name + else: + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + types_alias = self.code_model.get_unique_types_alias(serialize_namespace, self.client_namespace) + retval = f"{types_alias}.{self.name}" + return retval if is_operation_file or skip_quote else f'"{retval}"' + + def docstring_type(self, **kwargs: Any) -> str: + client_namespace = self.client_namespace + if self.code_model.options.get("generation-subdir"): + client_namespace += f".{self.code_model.options['generation-subdir']}" + return f"~{client_namespace}.types.{self.name}" + + @property + def instance_check_template(self) -> str: + return "isinstance({}, MutableMapping)" + + def imports(self, **kwargs: Any) -> FileImport: + file_import = FileImport(self.code_model) + serialize_namespace_type = kwargs.get("serialize_namespace_type") + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + relative_path = self.code_model.get_relative_import_path(serialize_namespace, self.client_namespace) + alias = self.code_model.get_unique_types_alias(serialize_namespace, self.client_namespace) + if serialize_namespace_type in [NamespaceType.OPERATION, NamespaceType.CLIENT]: + file_import.add_submodule_import( + relative_path, + "types", + ImportType.LOCAL, + alias=alias, + ) + return file_import diff --git a/packages/http-client-python/generator/pygen/codegen/models/operation.py b/packages/http-client-python/generator/pygen/codegen/models/operation.py index 38e94483ed5..b63e4205dbd 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/operation.py +++ b/packages/http-client-python/generator/pygen/codegen/models/operation.py @@ -458,6 +458,7 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements r.type and not isinstance(r.type, BinaryIteratorType) and not xml_serializable(str(r.default_content_type)) + and not (isinstance(r.type, ModelType) and r.type.is_typed_dict_only) for r in self.responses ): file_import.add_submodule_import(relative_path, "_deserialize", ImportType.LOCAL) diff --git a/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py b/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py index f64363ed5fb..91f2d39ec05 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py +++ b/packages/http-client-python/generator/pygen/codegen/models/paging_operation.py @@ -186,7 +186,11 @@ def imports(self, async_mode: bool, **kwargs: Any) -> FileImport: serialize_namespace, module_name="_utils.model_base" ) file_import.merge(self.item_type.imports(**kwargs)) - if self.default_error_deserialization(serialize_namespace) or self.need_deserialize: + need_deserialize_import = self.default_error_deserialization(serialize_namespace) or ( + self.need_deserialize + and not (isinstance(self.item_type, ModelType) and self.item_type.is_typed_dict_only) + ) + if need_deserialize_import: file_import.add_submodule_import(relative_path, "_deserialize", ImportType.LOCAL) if self.is_xml_paging: file_import.add_submodule_import("xml.etree", "ElementTree", ImportType.STDLIB, alias="ET") diff --git a/packages/http-client-python/generator/pygen/codegen/models/parameter.py b/packages/http-client-python/generator/pygen/codegen/models/parameter.py index 13a8c95d0c0..775a7db00f0 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/parameter.py +++ b/packages/http-client-python/generator/pygen/codegen/models/parameter.py @@ -180,7 +180,7 @@ def imports(self, async_mode: bool, **kwargs: Any) -> FileImport: if isinstance(self.type, CombinedType) and self.type.name: file_import.add_submodule_import( self.code_model.get_relative_import_path(serialize_namespace), - "_types", + "_unions", ImportType.LOCAL, TypingSection.TYPING, ) @@ -338,7 +338,7 @@ def method_location( # pylint: disable=too-many-return-statements ) -> ParameterMethodLocation: if not self.in_method_signature: raise ValueError(f"Parameter '{self.client_name}' is not in the method.") - if self.code_model.options["models-mode"] == "dpg" and self.in_flattened_body: + if self.code_model.options["models-mode"] in ("dpg", "typeddict") and self.in_flattened_body: return ParameterMethodLocation.KEYWORD_ONLY if self.grouper: return ParameterMethodLocation.POSITIONAL diff --git a/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py b/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py index cc864cf3ee1..6b4a4f8a444 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py +++ b/packages/http-client-python/generator/pygen/codegen/models/primitive_types.py @@ -9,6 +9,7 @@ from .base import BaseType from .imports import FileImport, ImportType, TypingSection +from .utils import NamespaceType if TYPE_CHECKING: from .code_model import CodeModel @@ -271,6 +272,8 @@ def docstring_type(self, **kwargs: Any) -> str: return "~" + self.type_annotation() def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "float" return "decimal.Decimal" def docstring_text(self, **kwargs: Any) -> str: @@ -281,7 +284,8 @@ def get_declaration(self, value: decimal.Decimal) -> str: def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) - file_import.add_import("decimal", ImportType.STDLIB) + if kwargs.get("serialize_namespace_type") != NamespaceType.TYPES_FILE: + file_import.add_import("decimal", ImportType.STDLIB) return file_import @property @@ -357,6 +361,8 @@ def docstring_type(self, **kwargs: Any) -> str: return "~" + self.type_annotation() def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "str" return "datetime.datetime" def docstring_text(self, **kwargs: Any) -> str: @@ -370,7 +376,8 @@ def get_declaration(self, value: datetime.datetime) -> str: def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) - file_import.add_import("datetime", ImportType.STDLIB) + if kwargs.get("serialize_namespace_type") != NamespaceType.TYPES_FILE: + file_import.add_import("datetime", ImportType.STDLIB) return file_import @property @@ -399,6 +406,8 @@ def docstring_type(self, **kwargs: Any) -> str: return "~" + self.type_annotation() def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "str" return "datetime.time" def docstring_text(self, **kwargs: Any) -> str: @@ -412,7 +421,8 @@ def get_declaration(self, value: datetime.time) -> str: def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) - file_import.add_import("datetime", ImportType.STDLIB) + if kwargs.get("serialize_namespace_type") != NamespaceType.TYPES_FILE: + file_import.add_import("datetime", ImportType.STDLIB) return file_import @property @@ -445,6 +455,8 @@ def docstring_type(self, **kwargs: Any) -> str: return "~" + self.type_annotation() def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "int" return "datetime.datetime" def docstring_text(self, **kwargs: Any) -> str: @@ -458,7 +470,8 @@ def get_declaration(self, value: datetime.datetime) -> str: def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) - file_import.add_import("datetime", ImportType.STDLIB) + if kwargs.get("serialize_namespace_type") != NamespaceType.TYPES_FILE: + file_import.add_import("datetime", ImportType.STDLIB) return file_import @property @@ -487,6 +500,8 @@ def docstring_type(self, **kwargs: Any) -> str: return "~" + self.type_annotation() def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "str" return "datetime.date" def docstring_text(self, **kwargs: Any) -> str: @@ -500,7 +515,8 @@ def get_declaration(self, value: datetime.date) -> str: def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) - file_import.add_import("datetime", ImportType.STDLIB) + if kwargs.get("serialize_namespace_type") != NamespaceType.TYPES_FILE: + file_import.add_import("datetime", ImportType.STDLIB) return file_import @property @@ -543,6 +559,8 @@ def docstring_type(self, **kwargs: Any) -> str: return "~" + self.type_annotation() def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "str" return "datetime.timedelta" def docstring_text(self, **kwargs: Any) -> str: @@ -556,7 +574,8 @@ def get_declaration(self, value: datetime.timedelta) -> str: def imports(self, **kwargs: Any) -> FileImport: file_import = FileImport(self.code_model) - file_import.add_import("datetime", ImportType.STDLIB) + if kwargs.get("serialize_namespace_type") != NamespaceType.TYPES_FILE: + file_import.add_import("datetime", ImportType.STDLIB) return file_import @property @@ -590,6 +609,11 @@ def serialization_type(self, **kwargs: Any) -> str: def docstring_type(self, **kwargs: Any) -> str: return "bytes" + def type_annotation(self, **kwargs: Any) -> str: + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + return "str" + return "bytes" + def get_declaration(self, value: str) -> str: return f'bytes("{value}", encoding="utf-8")' diff --git a/packages/http-client-python/generator/pygen/codegen/models/property.py b/packages/http-client-python/generator/pygen/codegen/models/property.py index ca18aba8ddf..d2a544c0511 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/property.py +++ b/packages/http-client-python/generator/pygen/codegen/models/property.py @@ -10,7 +10,7 @@ from .enum_type import EnumType from .base import BaseType from .imports import FileImport, ImportType -from .utils import add_to_description, add_to_pylint_disable +from .utils import add_to_description, add_to_pylint_disable, NamespaceType if TYPE_CHECKING: from .code_model import CodeModel @@ -29,6 +29,7 @@ def __init__( self.client_name: str = self.yaml_data["clientName"] self.type = type self.optional: bool = self.yaml_data["optional"] + self.nullable: bool = self.yaml_data.get("nullable", False) self.readonly: bool = self.yaml_data.get("readonly", False) self.visibility: list[str] = self.yaml_data.get("visibility", []) self.is_polymorphic: bool = self.yaml_data.get("isPolymorphic", False) @@ -110,6 +111,12 @@ def type_annotation(self, *, is_operation_file: bool = False, **kwargs: Any) -> if self.is_base_discriminator: return "str" types_type_annotation = self.type.type_annotation(is_operation_file=is_operation_file, **kwargs) + serialize_namespace_type = kwargs.get("serialize_namespace_type") + # In TypedDict types.py, Optional means nullable (not "not required" — that's handled by Required/total=False) + if serialize_namespace_type == NamespaceType.TYPES_FILE: + if self.nullable: + return f"Optional[{types_type_annotation}]" + return types_type_annotation if (self.optional and self.client_default_value is None) or self.readonly: return f"Optional[{types_type_annotation}]" return types_type_annotation @@ -152,15 +159,21 @@ def imports(self, **kwargs) -> FileImport: if self.is_discriminator and isinstance(self.type, EnumType): return file_import file_import.merge(self.type.imports(**kwargs)) - if (self.optional and self.client_default_value is None) or self.readonly: + serialize_namespace_type = kwargs.get("serialize_namespace_type") + if serialize_namespace_type == NamespaceType.TYPES_FILE: + # In TypedDict types.py, Optional means nullable + if self.nullable: + file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB) + elif (self.optional and self.client_default_value is None) or self.readonly: file_import.add_submodule_import("typing", "Optional", ImportType.STDLIB) if self.code_model.options["models-mode"] == "dpg": - serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) - file_import.add_submodule_import( - self.code_model.get_relative_import_path(serialize_namespace, module_name="_utils.model_base"), - "rest_discriminator" if self.is_discriminator else "rest_field", - ImportType.LOCAL, - ) + if serialize_namespace_type != NamespaceType.TYPES_FILE: + serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) + file_import.add_submodule_import( + self.code_model.get_relative_import_path(serialize_namespace, module_name="_utils.model_base"), + "rest_discriminator" if self.is_discriminator else "rest_field", + ImportType.LOCAL, + ) return file_import @classmethod diff --git a/packages/http-client-python/generator/pygen/codegen/models/response.py b/packages/http-client-python/generator/pygen/codegen/models/response.py index d37986146fd..99a90481319 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/response.py +++ b/packages/http-client-python/generator/pygen/codegen/models/response.py @@ -68,7 +68,12 @@ def result_property(self) -> str: def get_polymorphic_subtypes(self, polymorphic_subtypes: list["ModelType"]) -> None: if self.type: - self.type.get_polymorphic_subtypes(polymorphic_subtypes) + if isinstance(self.type, CombinedType): + target = self.type.target_model_subtype((ModelType,)) + if target: + target.get_polymorphic_subtypes(polymorphic_subtypes) + else: + self.type.get_polymorphic_subtypes(polymorphic_subtypes) def get_json_template_representation(self) -> Any: if not self.type: @@ -95,6 +100,7 @@ def serialization_type(self, **kwargs: Any) -> str: def type_annotation(self, **kwargs: Any) -> str: if self.type: kwargs["is_operation_file"] = True + kwargs["is_response"] = True type_annotation = self.type.type_annotation(**kwargs) if self.nullable: return f"Optional[{type_annotation}]" @@ -102,11 +108,13 @@ def type_annotation(self, **kwargs: Any) -> str: return "None" def docstring_text(self, **kwargs: Any) -> str: + kwargs["is_response"] = True if self.nullable and self.type: return f"{self.type.docstring_text(**kwargs)} or None" return self.type.docstring_text(**kwargs) if self.type else "None" def docstring_type(self, **kwargs: Any) -> str: + kwargs["is_response"] = True if self.nullable and self.type: return f"{self.type.docstring_type(**kwargs)} or None" return self.type.docstring_type(**kwargs) if self.type else "None" @@ -121,7 +129,7 @@ def imports(self, **kwargs: Any) -> FileImport: serialize_namespace = kwargs.get("serialize_namespace", self.code_model.namespace) file_import.add_submodule_import( self.code_model.get_relative_import_path(serialize_namespace), - "_types", + "_unions", ImportType.LOCAL, TypingSection.TYPING, ) @@ -165,7 +173,12 @@ def __init__(self, *args, **kwargs) -> None: ) def get_polymorphic_subtypes(self, polymorphic_subtypes: list["ModelType"]) -> None: - return self.item_type.get_polymorphic_subtypes(polymorphic_subtypes) + if isinstance(self.item_type, CombinedType): + target = self.item_type.target_model_subtype((ModelType,)) + if target: + target.get_polymorphic_subtypes(polymorphic_subtypes) + else: + self.item_type.get_polymorphic_subtypes(polymorphic_subtypes) def get_json_template_representation(self) -> Any: return self.item_type.get_json_template_representation() diff --git a/packages/http-client-python/generator/pygen/codegen/models/utils.py b/packages/http-client-python/generator/pygen/codegen/models/utils.py index cc980eff4d2..dd462d1d5bd 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/utils.py +++ b/packages/http-client-python/generator/pygen/codegen/models/utils.py @@ -41,14 +41,11 @@ class NamespaceType(str, Enum): OPERATION = "operation" CLIENT = "client" TYPES_FILE = "types_file" + UNIONS_FILE = "unions_file" LOCALS_LENGTH_LIMIT = 25 -REQUEST_BUILDER_BODY_VARIABLES_LENGTH = ( - 6 # how many body variables are present in a request builder -) +REQUEST_BUILDER_BODY_VARIABLES_LENGTH = 6 # how many body variables are present in a request builder -OPERATION_BODY_VARIABLES_LENGTH = ( - 14 # how many body variables are present in an operation -) +OPERATION_BODY_VARIABLES_LENGTH = 14 # how many body variables are present in an operation diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py b/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py index b949242ba54..e2acbcf10b9 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/__init__.py @@ -33,6 +33,7 @@ from .sample_serializer import SampleSerializer from .test_serializer import TestSerializer, TestGeneralSerializer from .types_serializer import TypesSerializer +from .unions_serializer import UnionsSerializer from ...utils import to_snake_case, VALID_PACKAGE_MODE from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import @@ -120,7 +121,7 @@ def keep_version_file(self) -> bool: # If parsing the version fails, we assume the version file is not valid and overwrite. return False - # pylint: disable=too-many-branches + # pylint: disable=too-many-branches,too-many-statements def serialize(self) -> None: # remove existing folders when generate from tsp if self.code_model.is_tsp and self.code_model.options.get("clear-output-folder"): @@ -200,12 +201,13 @@ def serialize(self) -> None: general_serializer.serialize_pkgutil_init_file(), ) - # _utils/py.typed/_types.py/_validation.py + # _utils/py.typed/_unions.py/types.py/_validation.py # is always put in top level namespace if self.code_model.is_top_namespace(client_namespace): self._serialize_and_write_top_level_folder(env=env, namespace=client_namespace) # add models folder if there are models in this namespace + is_typeddict_mode = self.code_model.options["models-mode"] == "typeddict" if ( self.code_model.has_non_json_models(client_namespace_type.models) or client_namespace_type.enums ) and self.code_model.options["models-mode"]: @@ -213,7 +215,27 @@ def serialize(self) -> None: env=env, namespace=client_namespace, models=client_namespace_type.models, - enums=client_namespace_type.enums, + enums=[] if is_typeddict_mode else client_namespace_type.enums, + ) + + # write types.py per namespace (alongside models/) + # Only generate types.py if at least one model/enum would be imported via types + # in operations (the model itself volunteers this via is_used_in_operations_via_types) + has_types_models = any( + m.is_used_in_operations_via_types for m in client_namespace_type.models if m.base != "json" + ) + has_types_enums = any(e.is_typeddict_mode for e in client_namespace_type.enums) + if has_types_models or has_types_enums: + generation_dir = self.code_model.get_generation_dir(client_namespace) + self.write_file( + generation_dir / Path("types.py"), + TypesSerializer( + code_model=self.code_model, + env=env, + client_namespace=client_namespace, + models=client_namespace_type.models, + enums=client_namespace_type.enums if is_typeddict_mode else None, + ).serialize(), ) if not self.code_model.options["models-mode"]: @@ -298,11 +320,19 @@ def _serialize_and_write_models_folder( ) -> None: # Write the models folder models_path = self.code_model.get_generation_dir(namespace) / "models" - serializer = DpgModelSerializer if self.code_model.options["models-mode"] == "dpg" else MsrestModelSerializer - if self.code_model.has_non_json_models(models): + models_mode = self.code_model.options["models-mode"] + if models_mode in ("dpg", "typeddict"): + serializer = DpgModelSerializer + else: + serializer = MsrestModelSerializer + # Filter out typed-dict-only models and typeddict copies — they only appear in types.py, not as model classes + class_models = [m for m in models if not m.is_typed_dict_only and m.base != "typeddict"] + if self.code_model.has_non_json_models(class_models): self.write_file( models_path / Path(f"{self.code_model.models_filename}.py"), - serializer(code_model=self.code_model, env=env, client_namespace=namespace, models=models).serialize(), + serializer( + code_model=self.code_model, env=env, client_namespace=namespace, models=class_models + ).serialize(), ) if enums: self.write_file( @@ -313,7 +343,7 @@ def _serialize_and_write_models_folder( ) self.write_file( models_path / Path("__init__.py"), - ModelInitSerializer(code_model=self.code_model, env=env, models=models, enums=enums).serialize(), + ModelInitSerializer(code_model=self.code_model, env=env, models=class_models, enums=enums).serialize(), ) self._keep_patch_file(models_path / Path("_patch.py"), env) @@ -442,7 +472,9 @@ def _serialize_client_and_config_files( # when there is client.py, there must be __init__.py self.write_file( generation_path / Path(f"{async_path}__init__.py"), - general_serializer.serialize_init_file([c for c in clients if c.has_operations]), + general_serializer.serialize_init_file( + [c for c in clients if c.has_operations], + ), ) # if there was a patch file before, we keep it @@ -519,11 +551,14 @@ def _serialize_and_write_top_level_folder(self, env: Environment, namespace: str general_serializer.serialize_validation_file(), ) - # write _types.py + # write _unions.py if self.code_model.named_unions: self.write_file( - generation_dir / Path("_types.py"), - TypesSerializer(code_model=self.code_model, env=env).serialize(), + generation_dir / Path("_unions.py"), + UnionsSerializer( + code_model=self.code_model, + env=env, + ).serialize(), ) # pylint: disable=line-too-long diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py index c615baf827f..2a671d7b220 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/builder_serializer.py @@ -196,6 +196,16 @@ def is_json_model_type(parameters: ParameterListType) -> bool: ) +def _is_dpg_or_typeddict_body(body_param: BodyParameter) -> bool: + """Check if a body parameter is a DPG model or a CombinedType wrapping one.""" + body_type = body_param.type + if isinstance(body_type, DPGModelType): + return True + if isinstance(body_type, CombinedType): + return body_type.target_model_subtype((DPGModelType,)) is not None + return False + + class _BuilderBaseSerializer(Generic[BuilderType]): def __init__(self, code_model: CodeModel, async_mode: bool, client_namespace: str) -> None: self.code_model = code_model @@ -712,8 +722,20 @@ def _serialize_body_parameter(self, builder: OperationType) -> list[str]: f"_{body_kwarg_name} = self._serialize.body({body_param.client_name}, " f"'{serialization_type}'{is_xml_cmd}{serialization_ctxt_cmd})" ) + elif self.code_model.options["models-mode"] == "typeddict": + # TypedDict-only models are plain dicts — no serialization needed + create_body_call = f"_{body_kwarg_name} = {body_param.client_name}" elif self.code_model.options["models-mode"] == "dpg": - if json_serializable(body_param.default_content_type): + # Check if this is a typeddict-only model within dpg mode — skip serialization + body_model_type = body_param.type + if isinstance(body_model_type, CombinedType): + body_model_type = body_model_type.target_model_subtype((DPGModelType,)) + is_typeddict_only_body = isinstance(body_model_type, DPGModelType) and getattr( + body_model_type, "is_typed_dict_only", False + ) + if is_typeddict_only_body: + create_body_call = f"_{body_kwarg_name} = {body_param.client_name}" + elif json_serializable(body_param.default_content_type): if hasattr(body_param.type, "encode") and body_param.type.encode: # type: ignore create_body_call = ( f"_{body_kwarg_name} = json.dumps({body_param.client_name}, " @@ -791,9 +813,9 @@ def _initialize_overloads(self, builder: OperationType, is_paging: bool = False) overload.request_builder.parameters.body_parameter.client_name for overload in builder.overloads ] all_dpg_model_overloads = False - if self.code_model.options["models-mode"] == "dpg" and builder.overloads: + if self.code_model.options["models-mode"] in ("dpg", "typeddict") and builder.overloads: all_dpg_model_overloads = all( - isinstance(o.parameters.body_parameter.type, DPGModelType) for o in builder.overloads + _is_dpg_or_typeddict_body(o.parameters.body_parameter) for o in builder.overloads ) if not all_dpg_model_overloads: for v in sorted(set(client_names), key=client_names.index): @@ -1002,6 +1024,12 @@ def response_deserialization( # pylint: disable=too-many-statements elif self.code_model.options["models-mode"] == "dpg": if builder.has_stream_response: deserialize_code.append("deserialized = response.content") + elif isinstance(response.type, ModelType) and response.type.is_typed_dict_only: + # Typed-dict-only models skip deserialization — return raw JSON + deserialize_code.append("if response.content:") + deserialize_code.append(" deserialized = response.json()") + deserialize_code.append("else:") + deserialize_code.append(" deserialized = None") else: format_filed = ( f', format="{response.type.encode}"' @@ -1437,18 +1465,23 @@ def _extract_data_callback( # pylint: disable=too-many-statements,too-many-bran ) pylint_disable = "" if self.code_model.options["models-mode"] == "dpg": - item_type = builder.item_type.type_annotation( - is_operation_file=True, serialize_namespace=self.serialize_namespace - ) - pylint_disable = ( - " # pylint: disable=protected-access" if getattr(builder.item_type, "internal", False) else "" - ) - list_of_elem_deserialized = [ - "_deserialize(", - f"{item_type},{pylint_disable}", - f"deserialized{access},", - ")", - ] + is_item_typed_dict_only = isinstance(builder.item_type, ModelType) and builder.item_type.is_typed_dict_only + if is_item_typed_dict_only: + # Typed-dict-only models skip deserialization — return raw JSON items + list_of_elem_deserialized = [f"deserialized{access}"] + else: + item_type = builder.item_type.type_annotation( + is_operation_file=True, serialize_namespace=self.serialize_namespace + ) + pylint_disable = ( + " # pylint: disable=protected-access" if getattr(builder.item_type, "internal", False) else "" + ) + list_of_elem_deserialized = [ + "_deserialize(", + f"{item_type},{pylint_disable}", + f"deserialized{access},", + ")", + ] else: list_of_elem_deserialized = [f"deserialized{access}"] list_of_elem_deserialized_str = "\n ".join(list_of_elem_deserialized) diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/model_init_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/model_init_serializer.py index c1ae3e6f563..879d56475a5 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/model_init_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/model_init_serializer.py @@ -32,7 +32,7 @@ def serialize(self) -> str: ", ".join(model_enum_name_intersection) ) ) - has_models = self.models + has_models = self.code_model.has_non_json_models(self.models) has_enums = self.enums template = self.env.get_template("model_init.py.jinja2") return template.render( diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py index c7523c61feb..87a0b2cc4aa 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/model_serializer.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from typing import Optional +from typing import Any, Optional from abc import ABC, abstractmethod from ..models import ModelType, Property, ConstantType, EnumValue, EnumType @@ -73,12 +73,19 @@ def _get_xml_deserializer_enum_type(prop: Property) -> Optional[EnumType]: return prop_type if isinstance(prop_type, EnumType) else None -def _documentation_string(prop: Property, description_keyword: str, docstring_type_keyword: str) -> list[str]: +def _documentation_string( + prop: Property, description_keyword: str, docstring_type_keyword: str, **kwargs: Any +) -> list[str]: retval: list[str] = [] sphinx_prefix = f":{description_keyword} {prop.client_name}:" description = prop.description(is_operation_file=False).replace("\\", "\\\\") retval.append(f"{sphinx_prefix} {description}" if description else sphinx_prefix) - retval.append(f":{docstring_type_keyword} {prop.client_name}: {prop.type.docstring_type()}") + # In the types file, use type_annotation (the serialized form) for docstrings + if kwargs.get("serialize_namespace_type") == NamespaceType.TYPES_FILE: + doc_type = prop.type.type_annotation(**kwargs) + else: + doc_type = prop.type.docstring_type(**kwargs) + retval.append(f":{docstring_type_keyword} {prop.client_name}: {doc_type}") return retval @@ -494,3 +501,120 @@ def global_pylint_disables(self) -> str: if final_result: return "# pylint: disable=" + ", ".join(final_result) return "" + + +class TypedDictModelSerializer(_ModelSerializer): + def _is_parent_discriminated_base(self, model: ModelType) -> bool: + """Check if any parent of this model is a discriminated base (has discriminated_subtypes).""" + return any(p.discriminated_subtypes for p in model.parents) + + def _reorder_models(self, models: list[ModelType]) -> list[ModelType]: + """Reorder so discriminated base Union aliases come after all their subtypes.""" + bases = [m for m in models if m.discriminated_subtypes] + non_bases = [m for m in models if not m.discriminated_subtypes] + return non_bases + bases + + def serialize(self) -> str: + template = self.env.get_template("model_container.py.jinja2") + return template.render( + code_model=self.code_model, + imports=FileImportSerializer(self.imports()), + str=str, + serializer=self, + models=self._reorder_models(self.models), + ) + + def imports(self) -> FileImport: + file_import = FileImport(self.code_model) + has_required = False + has_discriminated_union = False + for model in self.models: + if model.base == "json": + continue + if model.discriminated_subtypes: + has_discriminated_union = True + file_import.merge( + model.imports( + is_operation_file=False, + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.MODEL, + ) + ) + for prop in model.properties: + file_import.merge( + prop.imports( + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.MODEL, + called_by_property=True, + ) + ) + if not (prop.optional or prop.client_default_value is not None): + has_required = True + for parent in model.parents: + if parent.client_namespace != model.client_namespace and not parent.discriminated_subtypes: + file_import.add_submodule_import( + self.code_model.get_relative_import_path( + self.serialize_namespace, + self.code_model.get_imported_namespace_for_model(parent.client_namespace), + ), + parent.name, + ImportType.LOCAL, + ) + file_import.add_submodule_import("typing_extensions", "TypedDict", ImportType.STDLIB) + if has_required: + file_import.add_submodule_import("typing_extensions", "Required", ImportType.STDLIB) + if has_discriminated_union: + file_import.add_submodule_import("typing", "Union", ImportType.STDLIB) + return file_import + + def declare_model(self, model: ModelType) -> str: + # If the model's parent is a discriminated base, don't inherit from it + non_discriminated_parents = [p for p in model.parents if not p.discriminated_subtypes] + if non_discriminated_parents: + basename = ", ".join([m.name for m in non_discriminated_parents]) + return f"class {model.name}({basename}):{model.pylint_disable()}" + return f"class {model.name}(TypedDict, total=False):{model.pylint_disable()}" + + @staticmethod + def get_properties_to_declare(model: ModelType) -> list[Property]: + # Only exclude inherited properties from non-discriminated parents + non_discriminated_parents = [p for p in model.parents if not p.discriminated_subtypes] + if non_discriminated_parents: + parent_properties = [p for bm in non_discriminated_parents for p in bm.properties] + properties_to_declare = [ + p + for p in model.properties + if not any( + p.client_name == pp.client_name + and p.type_annotation() == pp.type_annotation() + and not p.is_base_discriminator + for pp in parent_properties + ) + ] + else: + properties_to_declare = model.properties + return properties_to_declare + + def declare_property(self, prop: Property) -> str: + type_annotation = prop.type_annotation(serialize_namespace=self.serialize_namespace) + is_optional = prop.optional or prop.client_default_value is not None + if is_optional: + return f"{prop.wire_name}: {type_annotation}" + return f"{prop.wire_name}: Required[{type_annotation}]" + + def initialize_properties(self, model: ModelType) -> list[str]: + return [] + + def need_init(self, model: ModelType) -> bool: + return False + + def discriminated_subtypes_union(self, model: ModelType) -> str: + subtypes = list(model.discriminated_subtypes.values()) + subtype_names = [s.name for s in subtypes] + return f"{model.name} = Union[{', '.join(subtype_names)}]" + + def is_discriminated_base(self, model: ModelType) -> bool: + return bool(model.discriminated_subtypes) + + def global_pylint_disables(self) -> str: + return "" diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/types_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/types_serializer.py index 749d8ca240c..5ee26ebf70c 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/types_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/types_serializer.py @@ -3,34 +3,285 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import keyword +import re +from typing import Optional +from ..models import ModelType, CodeModel +from ..models.enum_type import EnumType from ..models.imports import FileImport, ImportType from ..models.utils import NamespaceType +from ..models.property import Property +from .model_serializer import _documentation_string from .import_serializer import FileImportSerializer from .base_serializer import BaseSerializer +# Python builtin type names that can be shadowed by TypedDict field wire_names. +# When a field name matches one of these, all references to that builtin in type +# annotations within the same class are qualified as builtins.X. +_BUILTIN_TYPE_NAMES = frozenset( + { + "int", + "str", + "float", + "bool", + "list", + "dict", + "tuple", + "set", + "bytes", + "type", + "object", + "complex", + "frozenset", + "bytearray", + "memoryview", + } +) + + +def _qualify_shadowed_builtins(annotation: str, shadowed: frozenset[str]) -> str: + """Replace bare builtin type references with builtins.X when shadowed by a field name.""" + if not shadowed: + return annotation + for name in shadowed: + annotation = re.sub(rf"\b{name}\b", f"builtins.{name}", annotation) + return annotation + class TypesSerializer(BaseSerializer): + def __init__( + self, + code_model: CodeModel, + env, + client_namespace: Optional[str] = None, + models: Optional[list[ModelType]] = None, + enums: Optional[list["EnumType"]] = None, + ): + super().__init__(code_model=code_model, env=env, client_namespace=client_namespace) + self._models = models or [] + self._enums = enums or [] + + @property + def literal_enums(self) -> list[EnumType]: + """Enums to render as Literal type aliases in typeddict mode.""" + return sorted(self._enums) + + def declare_literal_enum(self, enum: EnumType) -> str: + """Generate a Literal type alias for an enum, e.g. MyColor = Literal["red", "blue"].""" + values = [enum.get_declaration(v.value) for v in enum.values] + return f"{enum.name} = Literal[{', '.join(values)}]" + + @property + def typeddict_models(self) -> list[ModelType]: + """Models that should be rendered as TypedDicts (excluding discriminated bases which become unions). + + When both a dpg model and its typeddict copy exist (same crossLanguageDefinitionId), + prefer the dpg model (it already renders as a TypedDict in types.py) and skip the copy. + """ + candidates = [m for m in self._models if m.base != "json" and not m.discriminated_subtypes] + seen_ids: dict[str, "ModelType"] = {} + result: list["ModelType"] = [] + for m in candidates: + clid = m.yaml_data.get("crossLanguageDefinitionId") + if clid and clid in seen_ids: + # Prefer the dpg model over the typeddict copy + if m.base == "dpg" and seen_ids[clid].base == "typeddict": + # Replace the typeddict copy with the dpg model + result = [r if r is not seen_ids[clid] else m for r in result] + seen_ids[clid] = m + # Otherwise skip this duplicate + continue + if clid: + seen_ids[clid] = m + result.append(m) + return result + + @property + def discriminated_base_models(self) -> list[ModelType]: + """Discriminated base models that become Union type aliases in types.py. + + Topologically sorted so that nested discriminated bases (e.g. Shark) + are defined before their parents (e.g. Fish = Union[Salmon, Shark]). + """ + bases = [m for m in self._models if m.base != "json" and m.discriminated_subtypes] + base_names = {m.name for m in bases} + sorted_bases: list[ModelType] = [] + visited: set[str] = set() + + def visit(model: ModelType) -> None: + if model.name in visited: + return + visited.add(model.name) + for subtype in model.discriminated_subtypes.values(): + if subtype.name in base_names: + visit(subtype) + sorted_bases.append(model) + + for m in bases: + visit(m) + return sorted_bases + + def discriminated_subtypes_union(self, model: ModelType) -> str: + """Generate a Union alias for a discriminated base using TypedDict subtype names.""" + subtypes = list(model.discriminated_subtypes.values()) + subtype_names = [s.name for s in subtypes] + return f"{model.name} = Union[{', '.join(subtype_names)}]" + + @staticmethod + def has_keyword_wire_names(model: ModelType) -> bool: + """Whether any property wire_name is a Python keyword or requires functional TypedDict form.""" + return any(keyword.iskeyword(p.wire_name) or not p.wire_name.isidentifier() for p in model.properties) + + @staticmethod + def get_shadowed_builtins(model: ModelType) -> frozenset[str]: + """Return the set of builtin type names shadowed by property wire_names in this model. + + Only includes a builtin if it is both used as a wire_name AND referenced + in a type annotation within the same model (otherwise no shadowing occurs). + """ + wire_builtins = {p.wire_name for p in model.properties if p.wire_name in _BUILTIN_TYPE_NAMES} + if not wire_builtins: + return frozenset() + # Check which of these builtins actually appear in type annotations + used = set() + for prop in model.properties: + annotation = prop.type_annotation() + for name in wire_builtins: + if re.search(rf"\b{name}\b", annotation): + used.add(name) + return frozenset(used) + def imports(self) -> FileImport: file_import = FileImport(self.code_model) - if self.code_model.named_unions: - file_import.add_submodule_import( - "typing", - "Union", - ImportType.STDLIB, - ) - for nu in self.code_model.named_unions: - file_import.merge( - nu.imports( - serialize_namespace=self.serialize_namespace, serialize_namespace_type=NamespaceType.TYPES_FILE + + literal_enums = self.literal_enums + if literal_enums: + file_import.add_submodule_import("typing", "Literal", ImportType.STDLIB) + + td_models = self.typeddict_models + if td_models or self.discriminated_base_models: + if td_models: + file_import.add_submodule_import("typing_extensions", "TypedDict", ImportType.STDLIB) + if self.discriminated_base_models: + file_import.add_submodule_import("typing", "Union", ImportType.STDLIB) + has_required = False + needs_builtins = False + for model in td_models: + file_import.merge( + model.imports( + is_operation_file=False, + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.TYPES_FILE, + ) ) - ) + for prop in model.properties: + file_import.merge( + prop.imports( + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.TYPES_FILE, + called_by_property=True, + ) + ) + if not (prop.optional or prop.client_default_value is not None): + has_required = True + if self.get_shadowed_builtins(model): + needs_builtins = True + for parent in model.parents: + if parent.client_namespace != model.client_namespace and not parent.discriminated_subtypes: + # Import parent class from sibling namespace's types module + file_import.add_submodule_import( + self.code_model.get_relative_import_path( + self.serialize_namespace, + parent.client_namespace, + module_name="types", + ), + parent.name, + ImportType.LOCAL, + ) + if has_required: + file_import.add_submodule_import("typing_extensions", "Required", ImportType.STDLIB) + if needs_builtins: + file_import.add_import("builtins", ImportType.STDLIB) return file_import + def declare_model(self, model: ModelType) -> str: + """Generate the class declaration or functional form for a TypedDict model. + + Uses functional form when any property wire_name is a Python keyword + (e.g. 'and', 'class') since keywords can't be identifiers in class bodies. + """ + if self.has_keyword_wire_names(model): + return "" # functional form is rendered separately + non_discriminated_parents = [p for p in model.parents if not p.discriminated_subtypes] + if non_discriminated_parents: + basename = ", ".join([m.name for m in non_discriminated_parents]) + return f"class {model.name}({basename}):{model.pylint_disable()}" + return f"class {model.name}(TypedDict, total=False):{model.pylint_disable()}" + + def declare_functional_model(self, model: ModelType) -> str: + """Generate a functional-form TypedDict for models with keyword wire_names. + + Functional form is required when any field name is a Python keyword. + All fields (including inherited) are included since functional form + can't specify a base class. + """ + shadowed = self.get_shadowed_builtins(model) + entries: list[str] = [] + for prop in model.properties: + type_annotation = prop.type_annotation( + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.TYPES_FILE, + ) + type_annotation = _qualify_shadowed_builtins(type_annotation, shadowed) + is_optional = prop.optional or prop.client_default_value is not None + if is_optional: + entries.append(f' "{prop.wire_name}": {type_annotation},') + else: + entries.append(f' "{prop.wire_name}": Required[{type_annotation}],') + fields = "\n".join(entries) + return f'{model.name} = TypedDict("{model.name}", {{\n{fields}\n}}, total=False)' + + @staticmethod + def get_properties_to_declare(model: ModelType) -> list[Property]: + if TypesSerializer.has_keyword_wire_names(model): + return [] # functional form handles all properties + non_discriminated_parents = [p for p in model.parents if not p.discriminated_subtypes] + if non_discriminated_parents: + parent_properties = [p for bm in non_discriminated_parents for p in bm.properties] + return [ + p + for p in model.properties + if not any( + p.client_name == pp.client_name + and p.type_annotation() == pp.type_annotation() + and not p.is_base_discriminator + for pp in parent_properties + ) + ] + return list(model.properties) + + def declare_property(self, prop: Property, shadowed_builtins: frozenset[str]) -> str: + type_annotation = prop.type_annotation( + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.TYPES_FILE, + ) + type_annotation = _qualify_shadowed_builtins(type_annotation, shadowed_builtins) + is_optional = prop.optional or prop.client_default_value is not None + if is_optional: + return f"{prop.wire_name}: {type_annotation}" + return f"{prop.wire_name}: Required[{type_annotation}]" + + @staticmethod + def variable_documentation_string(prop: Property) -> list[str]: + return _documentation_string(prop, "ivar", "vartype", serialize_namespace_type=NamespaceType.TYPES_FILE) + def serialize(self) -> str: - # Generate the models template = self.env.get_template("types.py.jinja2") return template.render( code_model=self.code_model, imports=FileImportSerializer(self.imports()), serializer=self, + literal_enums=self.literal_enums, + models=self.typeddict_models, + discriminated_bases=self.discriminated_base_models, ) diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/unions_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/unions_serializer.py new file mode 100644 index 00000000000..28d606ccf01 --- /dev/null +++ b/packages/http-client-python/generator/pygen/codegen/serializers/unions_serializer.py @@ -0,0 +1,44 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from ..models import CodeModel +from ..models.imports import FileImport, ImportType +from ..models.utils import NamespaceType +from .import_serializer import FileImportSerializer +from .base_serializer import BaseSerializer + + +class UnionsSerializer(BaseSerializer): + def __init__( + self, + code_model: CodeModel, + env, + ): + super().__init__(code_model=code_model, env=env) + + def imports(self) -> FileImport: + file_import = FileImport(self.code_model) + if self.code_model.named_unions: + file_import.add_submodule_import( + "typing", + "Union", + ImportType.STDLIB, + ) + for nu in self.code_model.named_unions: + file_import.merge( + nu.imports( + serialize_namespace=self.serialize_namespace, + serialize_namespace_type=NamespaceType.UNIONS_FILE, + ) + ) + return file_import + + def serialize(self) -> str: + template = self.env.get_template("unions.py.jinja2") + return template.render( + code_model=self.code_model, + imports=FileImportSerializer(self.imports()), + serializer=self, + ) diff --git a/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 index dc98f999c45..d91228cbdd8 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/model_container.py.jinja2 @@ -9,7 +9,7 @@ {{ imports }} {% for model in models %} -{% if model.base == "dpg" %} +{% if model.base == "dpg" or model.base == "typeddict" %} {% include "model_dpg.py.jinja2" %} {% elif model.base == "msrest" %} {% include "model_msrest.py.jinja2" %} diff --git a/packages/http-client-python/generator/pygen/codegen/templates/types.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/types.py.jinja2 index 19435f14e88..88dd962c5b9 100644 --- a/packages/http-client-python/generator/pygen/codegen/templates/types.py.jinja2 +++ b/packages/http-client-python/generator/pygen/codegen/templates/types.py.jinja2 @@ -4,6 +4,55 @@ {% endif %} {{ imports }} -{% for nu in code_model.named_unions %} -{{nu.name}} = {{nu.type_definition()}} +{% import 'operation_tools.jinja2' as op_tools %} +{% import "macros.jinja2" as macros %} +{% for enum in literal_enums %} + +{{ serializer.declare_literal_enum(enum) }} +{% if enum.yaml_data.get("description") %} +"""{{ op_tools.wrap_string(enum.yaml_data["description"], "\n") }}""" +{% endif %} +{% endfor %} +{% for model in models %} +{% if serializer.has_keyword_wire_names(model) %} + + +{{ serializer.declare_functional_model(model) }} +{{ model.name }}.__doc__ = """{{ op_tools.wrap_string(model.description(is_operation_file=False), "\n") }} + +{% if model.properties != None %} + {% for p in model.properties %} + {% for line in serializer.variable_documentation_string(p) %} +{{ macros.wrap_model_string(line, '\n') -}} + {% endfor %} + {% endfor %} +{% endif %} +""" +{% else %} + + +{{ serializer.declare_model(model) }} + """{{ op_tools.wrap_string(model.description(is_operation_file=False), "\n ") }} + + {% if model.properties != None %} + {% for p in model.properties %} + {% for line in serializer.variable_documentation_string(p) %} + {{ macros.wrap_model_string(line, '\n ') -}} + {% endfor %} + {% endfor %} + {% endif %} + """ + + {% set shadowed = serializer.get_shadowed_builtins(model) %} + {% for p in serializer.get_properties_to_declare(model)%} + {{ serializer.declare_property(p, shadowed) }} + {% set prop_description = p.description(is_operation_file=False).replace('"', '\\"') %} + {% if prop_description %} + """{{ macros.wrap_model_string(prop_description, '\n ', '\"\"\"') -}} + {% endif %} + {% endfor %} +{% endif %} +{% endfor %} +{% for model in discriminated_bases %} +{{ serializer.discriminated_subtypes_union(model) }} {% endfor %} diff --git a/packages/http-client-python/generator/pygen/codegen/templates/unions.py.jinja2 b/packages/http-client-python/generator/pygen/codegen/templates/unions.py.jinja2 new file mode 100644 index 00000000000..19435f14e88 --- /dev/null +++ b/packages/http-client-python/generator/pygen/codegen/templates/unions.py.jinja2 @@ -0,0 +1,9 @@ +# coding=utf-8 +{% if code_model.license_header %} +{{ code_model.license_header }} +{% endif %} + +{{ imports }} +{% for nu in code_model.named_unions %} +{{nu.name}} = {{nu.type_definition()}} +{% endfor %} diff --git a/packages/http-client-python/generator/pygen/preprocess/__init__.py b/packages/http-client-python/generator/pygen/preprocess/__init__.py index 8a2e30cdc87..780c0462cc8 100644 --- a/packages/http-client-python/generator/pygen/preprocess/__init__.py +++ b/packages/http-client-python/generator/pygen/preprocess/__init__.py @@ -36,18 +36,14 @@ def update_overload_section( if overload_s.get("type"): overload_s["type"] = original_s["type"] if overload_s.get("headers"): - for overload_h, original_h in zip( - overload_s["headers"], original_s["headers"] - ): + for overload_h, original_h in zip(overload_s["headers"], original_s["headers"]): if overload_h.get("type"): overload_h["type"] = original_h["type"] except KeyError as exc: raise ValueError(overload["name"]) from exc -def add_overload( - yaml_data: dict[str, Any], body_type: dict[str, Any], for_flatten_params=False -): +def add_overload(yaml_data: dict[str, Any], body_type: dict[str, Any], for_flatten_params=False): overload = copy.deepcopy(yaml_data) overload["isOverload"] = True overload["bodyParameter"]["type"] = body_type @@ -59,9 +55,7 @@ def add_overload( if for_flatten_params: overload["bodyParameter"]["flattened"] = True else: - overload["parameters"] = [ - p for p in overload["parameters"] if not p.get("inFlattenedBody") - ] + overload["parameters"] = [p for p in overload["parameters"] if not p.get("inFlattenedBody")] # for yaml sync, we need to make sure all of the responses, parameters, and exceptions' types have the same yaml id for overload_p, original_p in zip(overload["parameters"], yaml_data["parameters"]): overload_p["type"] = original_p["type"] @@ -69,9 +63,7 @@ def add_overload( update_overload_section(overload, yaml_data, "exceptions") # update content type to be an overloads content type - content_type_param = next( - p for p in overload["parameters"] if p["wireName"].lower() == "content-type" - ) + content_type_param = next(p for p in overload["parameters"] if p["wireName"].lower() == "content-type") content_type_param["inOverload"] = True content_type_param["inDocstring"] = True body_type_description = get_body_type_for_description(overload["bodyParameter"]) @@ -94,25 +86,18 @@ def add_overloads_for_body_param(yaml_data: dict[str, Any]) -> None: body_parameter = yaml_data["bodyParameter"] if not ( body_parameter["type"]["type"] == "combined" - and len(yaml_data["bodyParameter"]["type"]["types"]) - > len(yaml_data["overloads"]) + and len(yaml_data["bodyParameter"]["type"]["types"]) > len(yaml_data["overloads"]) ): return for body_type in body_parameter["type"]["types"]: - if any( - o - for o in yaml_data["overloads"] - if id(o["bodyParameter"]["type"]) == id(body_type) - ): + if any(o for o in yaml_data["overloads"] if id(o["bodyParameter"]["type"]) == id(body_type)): continue if body_type.get("type") == "model" and body_type.get("base") == "json": - yaml_data["overloads"].append( - add_overload(yaml_data, body_type, for_flatten_params=True) - ) + yaml_data["overloads"].append(add_overload(yaml_data, body_type, for_flatten_params=True)) + # Skip single-body JSON overload; the TypedDict overload replaces it + continue yaml_data["overloads"].append(add_overload(yaml_data, body_type)) - content_type_param = next( - p for p in yaml_data["parameters"] if p["wireName"].lower() == "content-type" - ) + content_type_param = next(p for p in yaml_data["parameters"] if p["wireName"].lower() == "content-type") content_type_param["inOverload"] = False content_type_param["inOverridden"] = True content_type_param["inDocstring"] = True @@ -122,19 +107,13 @@ def add_overloads_for_body_param(yaml_data: dict[str, Any]) -> None: content_type_param["optional"] = True -def update_description( - description: Optional[str], default_description: str = "" -) -> str: +def update_description(description: Optional[str], default_description: str = "") -> str: if not description: description = default_description description.rstrip(" ") # Don't append a trailing period when the description ends with a code block: the # period would land inside the rendered literal block (e.g. "]." ) and break Sphinx. - if ( - description - and description[-1] != "." - and not description_ends_with_code_block(description) - ): + if description and description[-1] != "." and not description_ends_with_code_block(description): description += "." return description @@ -197,9 +176,7 @@ def _get_etag_role(parameter: dict[str, Any]) -> Optional[str]: return parameter.get("etagRole") -def _pick_etag_slot( - candidates: list[dict[str, Any]], standard_wire_name: str -) -> Optional[dict[str, Any]]: +def _pick_etag_slot(candidates: list[dict[str, Any]], standard_wire_name: str) -> Optional[dict[str, Any]]: """Choose which etag-typed header should be promoted to the etag/match_condition slot. When more than one etag-typed header is present in an operation, prefer the @@ -227,25 +204,16 @@ def _resolve_etag_pair( Returns (property_if_match, property_if_none_match) — both None when there are no etag candidates. """ - property_if_match = _pick_etag_slot( - if_match_candidates, STANDARD_IF_MATCH_WIRE_NAME - ) - property_if_none_match = _pick_etag_slot( - if_none_match_candidates, STANDARD_IF_NONE_MATCH_WIRE_NAME - ) + property_if_match = _pick_etag_slot(if_match_candidates, STANDARD_IF_MATCH_WIRE_NAME) + property_if_none_match = _pick_etag_slot(if_none_match_candidates, STANDARD_IF_NONE_MATCH_WIRE_NAME) # Ensure the promoted pair come from the same family. When one slot is # standard and the other custom (cross-family), replace the custom slot # with a synthetic standard partner. Also synthesize the missing partner # when only one side is present. if property_if_match and property_if_none_match: - match_is_std = ( - get_wire_name_lower(property_if_match) == STANDARD_IF_MATCH_WIRE_NAME - ) - none_match_is_std = ( - get_wire_name_lower(property_if_none_match) - == STANDARD_IF_NONE_MATCH_WIRE_NAME - ) + match_is_std = get_wire_name_lower(property_if_match) == STANDARD_IF_MATCH_WIRE_NAME + none_match_is_std = get_wire_name_lower(property_if_none_match) == STANDARD_IF_NONE_MATCH_WIRE_NAME if match_is_std and not none_match_is_std: property_if_none_match = property_if_match.copy() property_if_none_match["wireName"] = STANDARD_IF_NONE_MATCH_WIRE_NAME @@ -293,15 +261,11 @@ def _process_operation_etag_headers( elif role == "ifNoneMatch": if_none_match_candidates.append(p) - property_if_match, property_if_none_match = _resolve_etag_pair( - if_match_candidates, if_none_match_candidates - ) + property_if_match, property_if_none_match = _resolve_etag_pair(if_match_candidates, if_none_match_candidates) if property_if_match and property_if_none_match: etag_params = {id(property_if_match), id(property_if_none_match)} - operation["parameters"] = [ - item for item in operation["parameters"] if id(item) not in etag_params - ] + [ + operation["parameters"] = [item for item in operation["parameters"] if id(item) not in etag_params] + [ property_if_match, property_if_none_match, ] @@ -320,9 +284,7 @@ def has_json_content_type(yaml_data: dict[str, Any]) -> bool: def has_multi_part_content_type(yaml_data: dict[str, Any]) -> bool: - return any( - ct for ct in yaml_data.get("contentTypes", []) if ct == "multipart/form-data" - ) + return any(ct for ct in yaml_data.get("contentTypes", []) if ct == "multipart/form-data") class PreProcessPlugin(YamlUpdatePlugin): @@ -344,6 +306,44 @@ def models_mode(self) -> Optional[str]: def is_tsp(self) -> bool: return self.options.get("tsp_file", False) + @staticmethod + def _find_existing_typeddict(code_model: dict[str, Any], cross_lang_id: Optional[str]) -> Optional[dict[str, Any]]: + """Find an existing typeddict copy with the given crossLanguageDefinitionId.""" + if not cross_lang_id: + return None + return next( + ( + t + for t in code_model["types"] + if t.get("type") == "model" + and t.get("base") == "typeddict" + and t.get("crossLanguageDefinitionId") == cross_lang_id + ), + None, + ) + + @staticmethod + def _insert_typeddict_overload( + code_model: dict[str, Any], + body_parameter: dict[str, Any], + source: dict[str, Any], + origin_type: str, + existing_td: Optional[dict[str, Any]], + ) -> None: + """Insert a typeddict type into the body parameter's combined types.""" + if origin_type == "model": + td_type = existing_td or {**source, "base": "typeddict"} + body_parameter["type"]["types"].insert(1, td_type) + if not existing_td: + code_model["types"].append(td_type) + else: + td_list_or_dict = copy.deepcopy(body_parameter["type"]["types"][0]) + td_elem = existing_td or {**source, "base": "typeddict"} + td_list_or_dict["elementType"] = td_elem + body_parameter["type"]["types"].insert(1, td_list_or_dict) + if not existing_td: + code_model["types"].append(td_elem) + def add_body_param_type( self, code_model: dict[str, Any], @@ -353,54 +353,73 @@ def add_body_param_type( if ( # pylint: disable=too-many-boolean-expressions body_parameter and body_parameter["type"]["type"] in ("model", "dict", "list") - and ( - has_json_content_type(body_parameter) - or (self.is_tsp and has_multi_part_content_type(body_parameter)) - ) + and (has_json_content_type(body_parameter) or (self.is_tsp and has_multi_part_content_type(body_parameter))) and not body_parameter["type"].get("xmlMetadata") and not any(t for t in ["flattened", "groupedBy"] if body_parameter.get(t)) ): origin_type = body_parameter["type"]["type"] model_type = ( - body_parameter["type"] - if origin_type == "model" - else body_parameter["type"].get("elementType", {}) + body_parameter["type"] if origin_type == "model" else body_parameter["type"].get("elementType", {}) ) is_dpg_model = model_type.get("base") == "dpg" + is_json_model = model_type.get("base") == "json" + is_typeddict_only = self.options["models-mode"] == "typeddict" + body_parameter["type"] = { "type": "combined", "types": [body_parameter["type"]], } - # don't add binary overload for multipart content type - if not (self.is_tsp and has_multi_part_content_type(body_parameter)): + # don't add binary overload for multipart content type or typeddict-only mode + if not (self.is_tsp and has_multi_part_content_type(body_parameter)) and not is_typeddict_only: body_parameter["type"]["types"].append(KNOWN_TYPES["binary"]) + # Add typeddict overload for non-spread dpg models if self.options["models-mode"] == "dpg" and is_dpg_model: - if origin_type == "model": - body_parameter["type"]["types"].insert(1, KNOWN_TYPES["any-object"]) - else: - # dict or list - # copy the original dict / list type - any_obj_list_or_dict = copy.deepcopy( - body_parameter["type"]["types"][0] + cross_lang_id = model_type.get("crossLanguageDefinitionId") + existing_td = self._find_existing_typeddict(code_model, cross_lang_id) + self._insert_typeddict_overload(code_model, body_parameter, model_type, origin_type, existing_td) + + # For spread bodies (json base), add a typeddict overload that references + # the original model. This replaces the JSON single-body overload. + if is_json_model: + cross_lang_id = model_type.get("crossLanguageDefinitionId") + original = None + if cross_lang_id: + original = next( + ( + t + for t in code_model["types"] + if t.get("type") == "model" + and t.get("crossLanguageDefinitionId") == cross_lang_id + and t is not model_type + ), + None, ) - any_obj_list_or_dict["elementType"] = KNOWN_TYPES["any-object"] - body_parameter["type"]["types"].insert(1, any_obj_list_or_dict) + + if is_typeddict_only and original: + # In typeddict-only mode, the original dpg model already renders + # as a TypedDict — reference it directly, no copy needed. + if origin_type == "model": + body_parameter["type"]["types"].insert(1, original) + else: + td_list_or_dict = copy.deepcopy(body_parameter["type"]["types"][0]) + td_list_or_dict["elementType"] = original + body_parameter["type"]["types"].insert(1, td_list_or_dict) + else: + source = original or model_type + existing_td = self._find_existing_typeddict(code_model, cross_lang_id) + self._insert_typeddict_overload(code_model, body_parameter, source, origin_type, existing_td) + code_model["types"].append(body_parameter["type"]) - def pad_reserved_words( - self, name: str, pad_type: PadType, yaml_type: dict[str, Any] - ) -> str: + def pad_reserved_words(self, name: str, pad_type: PadType, yaml_type: dict[str, Any]) -> str: # we want to pad hidden variables as well if not name: # we'll pass in empty operation groups sometime etc. return name if self.is_tsp: - reserved_words = { - k: (v + TSP_RESERVED_WORDS.get(k, [])) - for k, v in RESERVED_WORDS.items() - } + reserved_words = {k: (v + TSP_RESERVED_WORDS.get(k, [])) for k, v in RESERVED_WORDS.items()} else: reserved_words = RESERVED_WORDS name = pad_special_chars(name) @@ -417,24 +436,18 @@ def pad_reserved_words( def update_types(self, yaml_data: list[dict[str, Any]]) -> None: for type in yaml_data: for property in type.get("properties", []): - property["description"] = update_description( - property.get("description", "") - ) + property["description"] = update_description(property.get("description", "")) if not property.get("isExactName", False): property["clientName"] = self.pad_reserved_words( property["clientName"].lower(), PadType.PROPERTY, property ) add_redefined_builtin_info(property["clientName"], property) if type.get("name"): - pad_type = ( - PadType.MODEL if type["type"] == "model" else PadType.ENUM_CLASS - ) + pad_type = PadType.MODEL if type["type"] == "model" else PadType.ENUM_CLASS if type["type"] != "enumvalue": name = self.pad_reserved_words(type["name"], pad_type, type) type["name"] = name[0].upper() + name[1:] - type["description"] = update_description( - type.get("description", ""), type["name"] - ) + type["description"] = update_description(type.get("description", ""), type["name"]) type["snakeCaseName"] = to_snake_case(type["name"]) if type.get("values"): # we're enums - enum values are UPPER_CASE so no padding needed for reserved words @@ -452,24 +465,16 @@ def update_types(self, yaml_data: list[dict[str, Any]]) -> None: yaml_data.append(CLOUD_SETTING["type"]) # type: ignore def update_client(self, yaml_data: dict[str, Any]) -> None: - yaml_data["description"] = update_description( - yaml_data["description"], default_description=yaml_data["name"] - ) + yaml_data["description"] = update_description(yaml_data["description"], default_description=yaml_data["name"]) yaml_data["legacyFilename"] = to_snake_case(yaml_data["name"].replace(" ", "_")) parameters = yaml_data["parameters"] for parameter in parameters: self.update_parameter(parameter) if parameter["clientName"] == "credential": policy = parameter["type"].get("policy") - if ( - policy - and policy["type"] == "BearerTokenCredentialPolicy" - and self.azure_arm - ): + if policy and policy["type"] == "BearerTokenCredentialPolicy" and self.azure_arm: policy["type"] = "ARMChallengeAuthenticationPolicy" - policy["credentialScopes"] = [ - "https://management.azure.com/.default" - ] + policy["credentialScopes"] = ["https://management.azure.com/.default"] if ( (not self.version_tolerant or self.azure_arm) and parameters @@ -489,9 +494,7 @@ def update_client(self, yaml_data: dict[str, Any]) -> None: if self.azure_arm and yaml_data["parameters"]: yaml_data["parameters"].append(CLOUD_SETTING) - def get_operation_updater( - self, yaml_data: dict[str, Any] - ) -> Callable[[dict[str, Any], dict[str, Any]], None]: + def get_operation_updater(self, yaml_data: dict[str, Any]) -> Callable[[dict[str, Any], dict[str, Any]], None]: if yaml_data["discriminator"] == "lropaging": return self.update_lro_paging_operation if yaml_data["discriminator"] == "lro": @@ -503,8 +506,7 @@ def get_operation_updater( def update_parameter(self, yaml_data: dict[str, Any]) -> None: yaml_data["description"] = update_description(yaml_data.get("description", "")) if not yaml_data.get("isExactName", False) and not ( - yaml_data["location"] == "header" - and yaml_data["clientName"] in ("content_type", "accept") + yaml_data["location"] == "header" and yaml_data["clientName"] in ("content_type", "accept") ): yaml_data["clientName"] = self.pad_reserved_words( yaml_data["clientName"].lower(), PadType.PARAMETER, yaml_data @@ -520,16 +522,13 @@ def update_parameter(self, yaml_data: dict[str, Any]) -> None: prop: ( param_name if prop in exact_name_props - else self.pad_reserved_words( - param_name, PadType.PARAMETER, yaml_data - ).lower() + else self.pad_reserved_words(param_name, PadType.PARAMETER, yaml_data).lower() ) for prop, param_name in yaml_data["propertyToParameterName"].items() } wire_name_lower = (yaml_data.get("wireName") or "").lower() if yaml_data["location"] == "header" and ( - wire_name_lower in HEADERS_HIDE_IN_METHOD - or yaml_data.get("clientDefaultValue") == "multipart/form-data" + wire_name_lower in HEADERS_HIDE_IN_METHOD or yaml_data.get("clientDefaultValue") == "multipart/form-data" ): yaml_data["hideInMethod"] = True if self.version_tolerant and yaml_data["location"] == "header": @@ -538,10 +537,7 @@ def update_parameter(self, yaml_data: dict[str, Any]) -> None: headers_convert(yaml_data, ETAG_MATCH_DATA) elif role == "ifNoneMatch": headers_convert(yaml_data, ETAG_NONE_MATCH_DATA) - if ( - wire_name_lower in ["$host", "content-type", "accept"] - and yaml_data["type"]["type"] == "constant" - ): + if wire_name_lower in ["$host", "content-type", "accept"] and yaml_data["type"]["type"] == "constant": yaml_data["clientDefaultValue"] = yaml_data["type"]["value"] def update_operation( @@ -551,17 +547,13 @@ def update_operation( *, is_overload: bool = False, ) -> None: - yaml_data["groupName"] = self.pad_reserved_words( - yaml_data["groupName"], PadType.OPERATION_GROUP, yaml_data - ) + yaml_data["groupName"] = self.pad_reserved_words(yaml_data["groupName"], PadType.OPERATION_GROUP, yaml_data) yaml_data["groupName"] = to_snake_case(yaml_data["groupName"]) if yaml_data.get("isExactName", False): # exact() client name: keep the operation name as-is without lowercasing, # snake-casing, or padding reserved words. if yaml_data.get("isLroInitialOperation") is True: - yaml_data["name"] = ( - "_" + extract_original_name(yaml_data["name"]) + "_initial" - ) + yaml_data["name"] = "_" + extract_original_name(yaml_data["name"]) + "_initial" else: yaml_data["name"] = yaml_data["name"].lower() if yaml_data.get("isLroInitialOperation") is True: @@ -575,12 +567,8 @@ def update_operation( + "_initial" ) else: - yaml_data["name"] = self.pad_reserved_words( - yaml_data["name"], PadType.METHOD, yaml_data - ) - yaml_data["description"] = update_description( - yaml_data["description"], yaml_data["name"] - ) + yaml_data["name"] = self.pad_reserved_words(yaml_data["name"], PadType.METHOD, yaml_data) + yaml_data["description"] = update_description(yaml_data["description"], yaml_data["name"]) yaml_data["summary"] = update_description(yaml_data.get("summary", "")) body_parameter = yaml_data.get("bodyParameter") for parameter in yaml_data["parameters"]: @@ -601,12 +589,8 @@ def update_operation( def _update_lro_operation_helper(self, yaml_data: dict[str, Any]) -> None: for response in yaml_data.get("responses", []): response["discriminator"] = "lro" - response["pollerSync"] = ( - response.get("pollerSync") or "azure.core.polling.LROPoller" - ) - response["pollerAsync"] = ( - response.get("pollerAsync") or "azure.core.polling.AsyncLROPoller" - ) + response["pollerSync"] = response.get("pollerSync") or "azure.core.polling.LROPoller" + response["pollerAsync"] = response.get("pollerAsync") or "azure.core.polling.AsyncLROPoller" if not response.get("pollingMethodSync"): response["pollingMethodSync"] = ( "azure.mgmt.core.polling.arm_polling.ARMPolling" @@ -628,9 +612,7 @@ def update_lro_paging_operation( item_type: Optional[dict[str, Any]] = None, ) -> None: self.update_lro_operation(code_model, yaml_data, is_overload=is_overload) - self.update_paging_operation( - code_model, yaml_data, is_overload=is_overload, item_type=item_type - ) + self.update_paging_operation(code_model, yaml_data, is_overload=is_overload, item_type=item_type) yaml_data["discriminator"] = "lropaging" for response in yaml_data.get("responses", []): response["discriminator"] = "lropaging" @@ -653,16 +635,12 @@ def convert_initial_operation_response_type(data: dict[str, Any]) -> None: response["type"] = KNOWN_TYPES["binary"] self.update_operation(code_model, yaml_data, is_overload=is_overload) - self.update_operation( - code_model, yaml_data["initialOperation"], is_overload=is_overload - ) + self.update_operation(code_model, yaml_data["initialOperation"], is_overload=is_overload) convert_initial_operation_response_type(yaml_data["initialOperation"]) self._update_lro_operation_helper(yaml_data) for overload in yaml_data.get("overloads", []): self._update_lro_operation_helper(overload) - self.update_operation( - code_model, overload["initialOperation"], is_overload=True - ) + self.update_operation(code_model, overload["initialOperation"], is_overload=True) convert_initial_operation_response_type(overload["initialOperation"]) def update_paging_operation( @@ -680,9 +658,7 @@ def update_paging_operation( PadType.OPERATION_GROUP, yaml_data["nextOperation"], ) - yaml_data["nextOperation"]["groupName"] = to_snake_case( - yaml_data["nextOperation"]["groupName"] - ) + yaml_data["nextOperation"]["groupName"] = to_snake_case(yaml_data["nextOperation"]["groupName"]) for response in yaml_data["nextOperation"].get("responses", []): update_paging_response(response) response["itemType"] = item_type @@ -690,13 +666,9 @@ def update_paging_operation( update_paging_response(response) response["itemType"] = item_type for overload in yaml_data.get("overloads", []): - self.update_paging_operation( - code_model, overload, is_overload=True, item_type=item_type - ) + self.update_paging_operation(code_model, overload, is_overload=True, item_type=item_type) - def update_operation_groups( - self, code_model: dict[str, Any], client: dict[str, Any] - ) -> None: + def update_operation_groups(self, code_model: dict[str, Any], client: dict[str, Any]) -> None: operation_groups_yaml_data = client.get("operationGroups", []) for operation_group in operation_groups_yaml_data: operation_group["identifyName"] = self.pad_reserved_words( @@ -704,17 +676,13 @@ def update_operation_groups( PadType.OPERATION_GROUP, operation_group, ) - operation_group["identifyName"] = to_snake_case( - operation_group["identifyName"] - ) + operation_group["identifyName"] = to_snake_case(operation_group["identifyName"]) operation_group["propertyName"] = self.pad_reserved_words( operation_group["propertyName"], PadType.OPERATION_GROUP, operation_group, ) - operation_group["propertyName"] = to_snake_case( - operation_group["propertyName"] - ) + operation_group["propertyName"] = to_snake_case(operation_group["propertyName"]) operation_group["className"] = update_operation_group_class_name( client["name"], operation_group["className"] ) @@ -738,6 +706,4 @@ def update_yaml(self, yaml_data: dict[str, Any]) -> None: if __name__ == "__main__": # TSP pipeline will call this args, unknown_args = parse_args() - PreProcessPlugin( - output_folder=args.output_folder, tsp_file=args.tsp_file, **unknown_args - ).process() + PreProcessPlugin(output_folder=args.output_folder, tsp_file=args.tsp_file, **unknown_args).process() diff --git a/packages/http-client-python/generator/pygen/utils.py b/packages/http-client-python/generator/pygen/utils.py index fc198e148b1..15671b4f186 100644 --- a/packages/http-client-python/generator/pygen/utils.py +++ b/packages/http-client-python/generator/pygen/utils.py @@ -22,20 +22,13 @@ def description_ends_with_code_block(description: str) -> bool: block runs to the end of the description. """ lines = description.rstrip().splitlines() - directives = [ - i for i, line in enumerate(lines) if line.lstrip().startswith(CODE_BLOCK_MARKER) - ] + directives = [i for i, line in enumerate(lines) if line.lstrip().startswith(CODE_BLOCK_MARKER)] if not directives: return False - return all( - not line.strip() or line.startswith((" ", "\t")) - for line in lines[directives[-1] + 1 :] - ) + return all(not line.strip() or line.startswith((" ", "\t")) for line in lines[directives[-1] + 1 :]) -def update_enum_value( - name: str, value: Any, description: str, enum_type: dict[str, Any] -) -> dict[str, Any]: +def update_enum_value(name: str, value: Any, description: str, enum_type: dict[str, Any]) -> dict[str, Any]: return { "name": name, "type": "enumvalue", @@ -62,12 +55,7 @@ def replace_upper_characters(m) -> str: and len(name) - next_non_upper_case_char_location > 1 and name[next_non_upper_case_char_location].isalpha() ): - return ( - prefix - + match_str[: len(match_str) - 1] - + "_" - + match_str[len(match_str) - 1] - ) + return prefix + match_str[: len(match_str) - 1] + "_" + match_str[len(match_str) - 1] return prefix + match_str @@ -114,9 +102,7 @@ def _get_value(value: Any) -> Any: return value unknown_args_ret = { - ua.strip("--").split("=", maxsplit=1)[0]: _get_value( - ua.strip("--").split("=", maxsplit=1)[1] - ) + ua.strip("--").split("=", maxsplit=1)[0]: _get_value(ua.strip("--").split("=", maxsplit=1)[1]) for ua in unknown_args } return args, unknown_args_ret @@ -158,11 +144,7 @@ def build_policies( "self._config.user_agent_policy", "self._config.proxy_policy", "policies.ContentDecodePolicy(**kwargs)", - ( - f"{async_prefix}ARMAutoResourceProviderRegistrationPolicy()" - if is_arm - else None - ), + (f"{async_prefix}ARMAutoResourceProviderRegistrationPolicy()" if is_arm else None), "self._config.redirect_policy", "self._config.retry_policy", "self._config.authentication_policy", diff --git a/packages/http-client-python/tests/mock_api/azure/test_client_naming_typeddict.py b/packages/http-client-python/tests/mock_api/azure/test_client_naming_typeddict.py new file mode 100644 index 00000000000..cf9cf30401b --- /dev/null +++ b/packages/http-client-python/tests/mock_api/azure/test_client_naming_typeddict.py @@ -0,0 +1,62 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from client.naming.typeddict import NamingClient, models + + +@pytest.fixture +def client(): + with NamingClient() as client: + yield client + + +def test_client(client: NamingClient): + """TypedDict uses wire name 'defaultName', not client name 'client_name'.""" + client.property.client({"defaultName": True}) + + +def test_language(client: NamingClient): + """TypedDict uses wire name 'defaultName', not language-specific name 'python_name'.""" + client.property.language({"defaultName": True}) + + +def test_compatible_with_encoded_name(client: NamingClient): + """TypedDict uses encoded wire name 'wireName', not client name 'client_name'.""" + client.property.compatible_with_encoded_name({"wireName": True}) + + +def test_operation(client: NamingClient): + client.client_name() + + +def test_parameter(client: NamingClient): + client.parameter(client_name="true") + + +def test_header_request(client: NamingClient): + client.header.request(client_name="true") + + +def test_header_response(client: NamingClient): + assert client.header.response(cls=lambda x, y, z: z)["default-name"] == "true" + + +def test_model_client(client: NamingClient): + """TypedDict uses wire name 'defaultName', not client name 'default_name'.""" + client.model_client.client({"defaultName": True}) + + +def test_model_language(client: NamingClient): + """TypedDict uses wire name 'defaultName', not client name 'default_name'.""" + client.model_client.language({"defaultName": True}) + + +def test_union_enum_member_name(client: NamingClient): + client.union_enum.union_enum_member_name(models.ExtensibleEnum.CLIENT_ENUM_VALUE1) + + +def test_union_enum_name(client: NamingClient): + client.union_enum.union_enum_name(models.ClientExtensibleEnum.ENUM_VALUE1) diff --git a/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_async.py b/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_async.py index a7e00354034..538977da9b4 100644 --- a/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_async.py +++ b/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_async.py @@ -7,6 +7,7 @@ import pytest_asyncio from typetest.model.usage import models from typetest.model.usage.aio import UsageClient +from typetest.model.usage.types import InputRecord, InputOutputRecord @pytest_asyncio.fixture @@ -31,3 +32,18 @@ async def test_output(client: UsageClient): async def test_input_and_output(client: UsageClient): input_output = models.InputOutputRecord(required_prop="example-value") assert input_output == await client.input_and_output(input_output) + + +@pytest.mark.asyncio +async def test_input_typeddict(client: UsageClient): + # Pass a TypedDict (plain dict with wire names) instead of a model + result = await client.input({"requiredProp": "example-value"}) + assert result is None + + +@pytest.mark.asyncio +async def test_input_and_output_typeddict(client: UsageClient): + # Pass a TypedDict, get a model back + result = await client.input_and_output({"requiredProp": "example-value"}) + assert isinstance(result, models.InputOutputRecord) + assert result.required_prop == "example-value" diff --git a/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_typeddictonly_async.py b/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_typeddictonly_async.py new file mode 100644 index 00000000000..e87f171ec7d --- /dev/null +++ b/packages/http-client-python/tests/mock_api/shared/asynctests/test_typetest_model_usage_typeddictonly_async.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import pytest_asyncio +from typetest.model.usage.typeddictonly.aio import UsageClient +from typetest.model.usage.typeddictonly.types import InputRecord, OutputRecord, InputOutputRecord + + +@pytest_asyncio.fixture +async def client(): + async with UsageClient() as client: + yield client + + +@pytest.mark.asyncio +async def test_input(client: UsageClient): + # TypedDict-only: pass a plain dict matching the TypedDict schema + result = await client.input({"requiredProp": "example-value"}) + assert result is None + + +@pytest.mark.asyncio +async def test_output(client: UsageClient): + # TypedDict-only: output should be a plain dict (no model deserialization) + output = await client.output() + assert isinstance(output, dict) + assert output["requiredProp"] == "example-value" + + +@pytest.mark.asyncio +async def test_input_and_output(client: UsageClient): + # TypedDict-only: input a dict, get a dict back + result = await client.input_and_output({"requiredProp": "example-value"}) + assert isinstance(result, dict) + assert result["requiredProp"] == "example-value" diff --git a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_not_discriminated_typeddict.py b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_not_discriminated_typeddict.py new file mode 100644 index 00000000000..782791ab1e7 --- /dev/null +++ b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_not_discriminated_typeddict.py @@ -0,0 +1,37 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from typetest.model.notdiscriminated.typeddict import NotDiscriminatedClient +from typetest.model.notdiscriminated.typeddict.models import Siamese + + +@pytest.fixture +def client(): + with NotDiscriminatedClient() as client: + yield client + + +@pytest.fixture +def valid_body(): + return Siamese(name="abc", age=32, smart=True) + + +def test_get_valid(client, valid_body): + result = client.get_valid() + assert result["name"] == "abc" + assert result["age"] == 32 + assert result["smart"] is True + + +def test_post_valid(client, valid_body): + client.post_valid(valid_body) + + +def test_put_valid(client, valid_body): + result = client.put_valid(valid_body) + assert result["name"] == "abc" + assert result["age"] == 32 + assert result["smart"] is True diff --git a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_single_discriminator_typeddict.py b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_single_discriminator_typeddict.py new file mode 100644 index 00000000000..19335d1bc69 --- /dev/null +++ b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_inheritance_single_discriminator_typeddict.py @@ -0,0 +1,70 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from typetest.model.singlediscriminator.typeddict import SingleDiscriminatorClient +from typetest.model.singlediscriminator.typeddict.models import Sparrow, Eagle + + +@pytest.fixture +def client(): + with SingleDiscriminatorClient() as client: + yield client + + +@pytest.fixture +def valid_body(): + return Sparrow(wingspan=1, kind="sparrow") + + +def test_get_model(client): + result = client.get_model() + assert result["wingspan"] == 1 + assert result["kind"] == "sparrow" + + +def test_put_model(client, valid_body): + client.put_model(valid_body) + + +@pytest.fixture +def recursive_body(): + return Eagle( + wingspan=5, + kind="eagle", + partner={"wingspan": 2, "kind": "goose"}, + friends=[{"wingspan": 2, "kind": "seagull"}], + hate={"key3": {"wingspan": 1, "kind": "sparrow"}}, + ) + + +def test_get_recursive_model(client): + result = client.get_recursive_model() + assert result["wingspan"] == 5 + assert result["kind"] == "eagle" + assert result["partner"]["kind"] == "goose" + assert result["friends"][0]["kind"] == "seagull" + assert result["hate"]["key3"]["kind"] == "sparrow" + + +def test_put_recursive_model(client, recursive_body): + client.put_recursive_model(recursive_body) + + +def test_get_missing_discriminator(client): + result = client.get_missing_discriminator() + assert result["wingspan"] == 1 + + +def test_get_wrong_discriminator(client): + result = client.get_wrong_discriminator() + assert result["wingspan"] == 1 + assert result["kind"] == "wrongKind" + + +def test_get_legacy_model(client): + result = client.get_legacy_model() + assert result["size"] == 20 + assert result["kind"] == "t-rex" diff --git a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage.py b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage.py index c9ef0e63e7f..1dc7bb7bbf5 100644 --- a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage.py +++ b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage.py @@ -5,6 +5,7 @@ # -------------------------------------------------------------------------- import pytest from typetest.model.usage import UsageClient, models +from typetest.model.usage.types import InputRecord, InputOutputRecord @pytest.fixture @@ -26,3 +27,16 @@ def test_output(client: UsageClient): def test_input_and_output(client: UsageClient): input_output = models.InputOutputRecord(required_prop="example-value") assert input_output == client.input_and_output(input_output) + + +def test_input_typeddict(client: UsageClient): + # Pass a TypedDict (plain dict with wire names) instead of a model + result = client.input({"requiredProp": "example-value"}) + assert result is None + + +def test_input_and_output_typeddict(client: UsageClient): + # Pass a TypedDict, get a model back + result = client.input_and_output({"requiredProp": "example-value"}) + assert isinstance(result, models.InputOutputRecord) + assert result.required_prop == "example-value" diff --git a/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage_typeddictonly.py b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage_typeddictonly.py new file mode 100644 index 00000000000..586eb7be791 --- /dev/null +++ b/packages/http-client-python/tests/mock_api/shared/test_typetest_model_usage_typeddictonly.py @@ -0,0 +1,44 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from typetest.model.usage.typeddictonly import UsageClient +from typetest.model.usage.typeddictonly.types import InputRecord, OutputRecord, InputOutputRecord + + +@pytest.fixture +def client(): + with UsageClient() as client: + yield client + + +def test_input(client: UsageClient): + # TypedDict-only: pass a plain dict matching the TypedDict schema + result = client.input({"requiredProp": "example-value"}) + assert result is None + + +def test_output(client: UsageClient): + # TypedDict-only: output should be a plain dict (no model deserialization) + output = client.output() + assert isinstance(output, dict) + assert output["requiredProp"] == "example-value" + + +def test_input_and_output(client: UsageClient): + # TypedDict-only: input a dict, get a dict back + result = client.input_and_output({"requiredProp": "example-value"}) + assert isinstance(result, dict) + assert result["requiredProp"] == "example-value" + + +def test_no_model_classes(): + """Verify that typed-dict-only models don't generate model classes.""" + from typetest.model.usage.typeddictonly import models + + # models.__all__ should be empty — no model classes exported + assert models.__all__ == [] + # The TypedDicts should only exist in the types module + assert hasattr(InputRecord, "__required_keys__") or hasattr(InputRecord, "__annotations__") diff --git a/packages/http-client-python/tests/mock_api/shared/unittests/test_readme.py b/packages/http-client-python/tests/mock_api/shared/unittests/test_readme.py index 4aa24a62964..c03c697af5e 100644 --- a/packages/http-client-python/tests/mock_api/shared/unittests/test_readme.py +++ b/packages/http-client-python/tests/mock_api/shared/unittests/test_readme.py @@ -16,6 +16,7 @@ SKIP_PACKAGES = { ("azure", "service-multiple-services"), ("azure", "azure-client-generator-core-client-initialization"), + ("azure", "azure-client-generator-core-response-as-bool"), # TypedDict test packages exist in the azure-sdk-for-python baseline but # are not yet in regenerate.ts on main, so their READMEs are deleted during # baseline reset and never recreated. diff --git a/packages/http-client-python/tests/unit/test_typeddict.py b/packages/http-client-python/tests/unit/test_typeddict.py new file mode 100644 index 00000000000..7f624847e5b --- /dev/null +++ b/packages/http-client-python/tests/unit/test_typeddict.py @@ -0,0 +1,253 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +"""Tests for TypedDict generation, unions generation, and models-mode interactions.""" + +from jinja2 import PackageLoader, Environment + +from pygen.codegen.models import CodeModel, JSONModelType, DPGModelType +from pygen.codegen.models.model_type import TypedDictModelType +from pygen.codegen.serializers.types_serializer import TypesSerializer +from pygen.codegen.serializers.unions_serializer import UnionsSerializer + + +def _make_code_model(models_mode="dpg"): + return CodeModel( + { + "clients": [ + { + "name": "client", + "namespace": "blah", + "moduleName": "blah", + "parameters": [], + "url": "", + "operationGroups": [], + } + ], + "namespace": "namespace", + }, + options={ + "show-send-request": True, + "builders-visibility": "public", + "show-operations": True, + "models-mode": models_mode, + "flavor": "unbranded", + "client-side-validation": False, + }, + ) + + +def _make_model(code_model, name, model_cls=None, properties=None): + """Create a model of the given class attached to code_model.""" + if model_cls is None: + if code_model.options["models-mode"] == "typeddict": + model_cls = TypedDictModelType + elif code_model.options["models-mode"] == "dpg": + model_cls = DPGModelType + else: + model_cls = JSONModelType + return model_cls( + yaml_data={ + "name": name, + "type": "model", + "snakeCaseName": name.lower(), + "usage": 2, + }, + code_model=code_model, + properties=properties or [], + ) + + +def _make_env(): + return Environment( + loader=PackageLoader("pygen.codegen", "templates"), + trim_blocks=True, + lstrip_blocks=True, + ) + + +# ---------- models-mode=none ---------- + + +def test_models_mode_none_produces_json_model_type(): + """When models-mode is none (False), all models should be JSONModelType.""" + code_model = _make_code_model(models_mode=False) + model = _make_model(code_model, "Foo", model_cls=JSONModelType) + assert model.base == "json" + + +def test_models_mode_none_no_typeddict_models(): + """TypesSerializer.typeddict_models should be empty when models-mode=none.""" + code_model = _make_code_model(models_mode=False) + m1 = _make_model(code_model, "Foo", model_cls=JSONModelType) + m2 = _make_model(code_model, "Bar", model_cls=JSONModelType) + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[m1, m2]) + assert ts.typeddict_models == [] + + +def test_models_mode_none_types_file_has_no_typeddict_imports(): + """When models-mode=none, the types.py should not import TypedDict.""" + code_model = _make_code_model(models_mode=False) + m1 = _make_model(code_model, "Foo", model_cls=JSONModelType) + code_model.model_types = [m1] + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[m1]) + output = ts.serialize() + assert "TypedDict" not in output + assert "Required" not in output + + +# ---------- models-mode=dpg ---------- + + +def test_models_mode_dpg_typeddict_models_included(): + """DPG models have base='dpg', not 'json', so they appear in typeddict_models.""" + code_model = _make_code_model(models_mode="dpg") + m1 = _make_model(code_model, "Foo", model_cls=DPGModelType) + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[m1]) + # DPG models have base != "json" so they DO appear in typeddict_models + assert len(ts.typeddict_models) == 1 + + +# ---------- models-mode=typeddict ---------- + + +def test_models_mode_typeddict_models_included(): + """TypedDictModelType models should appear in typeddict_models.""" + code_model = _make_code_model(models_mode="typeddict") + m1 = _make_model(code_model, "Foo", model_cls=TypedDictModelType) + m2 = _make_model(code_model, "Bar", model_cls=TypedDictModelType) + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[m1, m2]) + assert len(ts.typeddict_models) == 2 + + +def test_models_mode_typeddict_serialize_contains_class(): + """Serialized types.py output should contain TypedDict class definitions.""" + code_model = _make_code_model(models_mode="typeddict") + m1 = _make_model(code_model, "Foo", model_cls=TypedDictModelType) + code_model.model_types = [m1] + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[m1]) + output = ts.serialize() + assert "class Foo(TypedDict, total=False):" in output + assert "TypedDict" in output + + +def test_types_file_has_no_named_unions(): + """Serialized types.py should not contain named union definitions.""" + code_model = _make_code_model(models_mode="dpg") + m1 = _make_model(code_model, "Foo", model_cls=DPGModelType) + code_model.model_types = [m1] + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[m1]) + output = ts.serialize() + # Named unions should be in _unions.py, not types.py + assert "named_unions" not in output + + +# ---------- unions serializer ---------- + + +def test_unions_serializer_no_unions(): + """UnionsSerializer with no named unions should produce minimal output.""" + code_model = _make_code_model(models_mode="dpg") + + env = _make_env() + us = UnionsSerializer(code_model=code_model, env=env) + output = us.serialize() + assert "TypedDict" not in output + assert "Union" not in output + + +# ---------- typed-dict-only ---------- + + +def _make_typed_dict_only_model(code_model, name, **extra_yaml): + """Create a TypedDictModelType with typedDictOnly=True.""" + yaml_data = { + "name": name, + "type": "model", + "snakeCaseName": name.lower(), + "usage": 2, + "typedDictOnly": True, + **extra_yaml, + } + return TypedDictModelType( + yaml_data=yaml_data, + code_model=code_model, + properties=[], + ) + + +def test_typed_dict_only_property(): + """is_typed_dict_only should be True when yaml_data has typedDictOnly=True or models-mode is typeddict.""" + code_model = _make_code_model(models_mode="typeddict") + model = _make_typed_dict_only_model(code_model, "Foo") + assert model.is_typed_dict_only is True + + # In typeddict mode, ALL models are typed-dict-only + normal_model = _make_model(code_model, "Bar", model_cls=TypedDictModelType) + assert normal_model.is_typed_dict_only is True + + # In dpg mode, only models with typedDictOnly=True are typed-dict-only + dpg_code_model = _make_code_model(models_mode="dpg") + dpg_normal = _make_model(dpg_code_model, "Baz", model_cls=TypedDictModelType) + assert dpg_normal.is_typed_dict_only is False + + +def test_typed_dict_only_excluded_from_public_model_types(): + """Typed-dict-only models should not appear in public_model_types.""" + code_model = _make_code_model(models_mode="typeddict") + normal = _make_model(code_model, "Normal", model_cls=TypedDictModelType) + td_only = _make_typed_dict_only_model(code_model, "TdOnly") + code_model.model_types = [normal, td_only] + + public = code_model.public_model_types + # In typeddict mode, all models are typed-dict-only and excluded from public model types + assert normal not in public + assert td_only not in public + + +def test_typed_dict_only_still_in_types_file(): + """Typed-dict-only models should still appear in types.py as TypedDicts.""" + code_model = _make_code_model(models_mode="typeddict") + td_only = _make_typed_dict_only_model(code_model, "MyModel") + code_model.model_types = [td_only] + + env = _make_env() + ts = TypesSerializer(code_model=code_model, env=env, models=[td_only]) + output = ts.serialize() + assert "class MyModel(TypedDict, total=False):" in output + + +def test_typed_dict_only_type_annotation(): + """Typed-dict-only models should use types.Name, not _models.Name.""" + code_model = _make_code_model(models_mode="typeddict") + model = _make_typed_dict_only_model(code_model, "Foo") + + # In operation files, should be types.Name + annotation = model.type_annotation(is_operation_file=True) + assert annotation == "types.Foo" + assert "_models" not in annotation + + +def test_typed_dict_only_docstring_type(): + """Typed-dict-only models should reference types module, not models.""" + code_model = _make_code_model(models_mode="typeddict") + model = _make_typed_dict_only_model(code_model, "Foo") + + docstring = model.docstring_type() + assert "types.Foo" in docstring + assert "models.Foo" not in docstring