|
3 | 3 | from typing import Literal, Optional |
4 | 4 |
|
5 | 5 | import pytest |
6 | | -from pydantic import BaseModel, Field |
| 6 | +from pydantic import BaseModel, Field, computed_field |
7 | 7 | from typing_extensions import Annotated |
8 | 8 |
|
9 | 9 | from aws_lambda_powertools.event_handler import APIGatewayRestResolver |
@@ -110,3 +110,79 @@ def create_todo(todo: TodoEnvelope): ... |
110 | 110 |
|
111 | 111 | # THEN the schema should be valid |
112 | 112 | assert openapi31_schema(schema) |
| 113 | + |
| 114 | + |
| 115 | +@pytest.mark.usefixtures("pydanticv2_only") |
| 116 | +def test_openapi_schema_includes_computed_field(): |
| 117 | + # GIVEN a model with a computed_field |
| 118 | + class User(BaseModel): |
| 119 | + first_name: str |
| 120 | + last_name: str |
| 121 | + |
| 122 | + @computed_field |
| 123 | + @property |
| 124 | + def full_name(self) -> str: |
| 125 | + return f"{self.first_name} {self.last_name}" |
| 126 | + |
| 127 | + # GIVEN APIGatewayRestResolver with a handler returning that model |
| 128 | + app = APIGatewayRestResolver(enable_validation=True) |
| 129 | + |
| 130 | + @app.get("/user") |
| 131 | + def get_user() -> User: |
| 132 | + return User(first_name="John", last_name="Doe") |
| 133 | + |
| 134 | + # WHEN we get the schema |
| 135 | + schema = json.loads(app.get_openapi_json_schema()) |
| 136 | + |
| 137 | + # THEN the computed_field should appear in the response schema |
| 138 | + user_schema = schema["components"]["schemas"]["User"] |
| 139 | + assert "full_name" in user_schema["properties"] |
| 140 | + assert user_schema["properties"]["full_name"]["type"] == "string" |
| 141 | + assert user_schema["properties"]["full_name"].get("readOnly") is True |
| 142 | + |
| 143 | + |
| 144 | +@pytest.mark.usefixtures("pydanticv2_only") |
| 145 | +def test_openapi_schema_computed_field_not_in_request_body(): |
| 146 | + # GIVEN a model with a computed_field used as both request and response |
| 147 | + class Item(BaseModel): |
| 148 | + price: float |
| 149 | + quantity: int |
| 150 | + |
| 151 | + @computed_field |
| 152 | + @property |
| 153 | + def total(self) -> float: |
| 154 | + return self.price * self.quantity |
| 155 | + |
| 156 | + # GIVEN APIGatewayRestResolver with handlers using the model |
| 157 | + app = APIGatewayRestResolver(enable_validation=True) |
| 158 | + |
| 159 | + @app.post("/items") |
| 160 | + def create_item(item: Item) -> Item: |
| 161 | + return item |
| 162 | + |
| 163 | + # WHEN we get the schema |
| 164 | + schema = json.loads(app.get_openapi_json_schema()) |
| 165 | + |
| 166 | + # THEN the request body schema should NOT include computed_field |
| 167 | + request_body = schema["paths"]["/items"]["post"]["requestBody"] |
| 168 | + request_ref = request_body["content"]["application/json"]["schema"]["$ref"] |
| 169 | + request_schema_name = request_ref.split("/")[-1] |
| 170 | + |
| 171 | + # THEN the response schema SHOULD include computed_field |
| 172 | + response_ref = schema["paths"]["/items"]["post"]["responses"]["200"]["content"]["application/json"]["schema"][ |
| 173 | + "$ref" |
| 174 | + ] |
| 175 | + response_schema_name = response_ref.split("/")[-1] |
| 176 | + |
| 177 | + # When input/output schemas are separate, we expect different schema names |
| 178 | + # When they share a schema, computed_field should be present |
| 179 | + if request_schema_name == response_schema_name: |
| 180 | + # Shared schema - computed_field should be present (serialization mode wins) |
| 181 | + item_schema = schema["components"]["schemas"][response_schema_name] |
| 182 | + assert "total" in item_schema["properties"] |
| 183 | + else: |
| 184 | + # Separate schemas |
| 185 | + input_schema = schema["components"]["schemas"][request_schema_name] |
| 186 | + output_schema = schema["components"]["schemas"][response_schema_name] |
| 187 | + assert "total" not in input_schema["properties"] |
| 188 | + assert "total" in output_schema["properties"] |
0 commit comments