diff --git a/.python-version b/.python-version index 2c073331..e4fba218 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.12 diff --git a/.replit b/.replit index 5b89454c..9a6274c0 100644 --- a/.replit +++ b/.replit @@ -1,6 +1,6 @@ run = "poetry run pytest tests" -modules = ["python-3.11"] +modules = ["python-3.12"] [nix] -channel = "stable-23_11" +channel = "stable-24_11" diff --git a/flake.nix b/flake.nix index 684a5e2c..5ad206f6 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,7 @@ LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib"; }; packages = replitNixDeps ++ [ - pkgs.python311 + pkgs.python312 pkgs.uv ]; }; diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 51e45a2d..90a695b1 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -129,7 +129,7 @@ async def send_subscription( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure( "subscription", service_name, procedure_name ) as span_handle: @@ -157,7 +157,7 @@ async def send_stream( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() async for msg in session.send_stream( diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index eaa28f55..7126163d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -64,7 +64,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river """ @@ -763,6 +763,27 @@ def generate_individual_service( input_base_class: Literal["TypedDict"] | Literal["BaseModel"], ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] + + def _type_adapter_definition( + type_adapter_name: TypeName, + _type: TypeExpression, + module_info: list[ModuleName], + ) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]: + rendered_type_expr = render_type_expr(_type) + return ( + [type_adapter_name], + module_info, + [ + FileContents( + dedent(f""" + {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( + TypeAdapter({rendered_type_expr}) + ) + """) + ) + ], + ) + class_name = ClassName(f"{schema_name.title()}Service") current_chunks: List[str] = [ dedent( @@ -798,6 +819,10 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=False, ) + input_type_name = extract_inner_type(input_type) + input_type_type_adapter_name = TypeName( + f"{render_literal_type(input_type_name)}TypeAdapter" + ) serdes.append( ( [extract_inner_type(input_type), *encoder_names], @@ -805,6 +830,11 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) + serdes.append( + _type_adapter_definition( + input_type_type_adapter_name, input_type, module_info + ) + ) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -812,13 +842,23 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) + output_type_name = extract_inner_type(output_type) serdes.append( ( - [extract_inner_type(output_type), *encoder_names], + [output_type_name, *encoder_names], module_info, output_chunks, ) ) + output_type_type_adapter_name = TypeName( + f"{render_literal_type(output_type_name)}TypeAdapter" + ) + serdes.append( + _type_adapter_definition( + output_type_type_adapter_name, output_type, module_info + ) + ) + output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( procedure.errors, @@ -828,27 +868,43 @@ def __init__(self, client: river.Client[Any]): permit_unknown_members=True, ) if isinstance(error_type, NoneTypeExpr): - error_type = TypeName("RiverError") + error_type_name = TypeName("RiverError") + error_type = error_type_name else: - serdes.append( - ([extract_inner_type(error_type)], module_info, errors_chunks) - ) + error_type_name = extract_inner_type(error_type) + serdes.append(([error_type_name], module_info, errors_chunks)) + else: - error_type = TypeName("RiverError") - output_or_error_type = UnionTypeExpr([output_type, error_type]) + error_type_name = TypeName("RiverError") + error_type = error_type_name + + error_type_type_adapter_name = TypeName( + f"{render_literal_type(error_type_name)}TypeAdapter" + ) + if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": + if len(module_info) == 0: + module_info = output_module_info + serdes.append( + _type_adapter_definition( + error_type_type_adapter_name, error_type, module_info + ) + ) + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. + output_type_adapter = render_literal_type(output_type_type_adapter_name) parse_output_method = f"""\ - lambda x: TypeAdapter({render_type_expr(output_type)}) + lambda x: {output_type_adapter} .validate_python( x # type: ignore[arg-type] ) """ + error_type_adapter = render_literal_type(error_type_type_adapter_name) parse_error_method = f"""\ - lambda x: TypeAdapter({render_type_expr(error_type)}) + lambda x: {error_type_adapter} .validate_python( x # type: ignore[arg-type] ) @@ -871,9 +927,18 @@ def __init__(self, client: river.Client[Any]): else: render_init_method = f"encode_{render_literal_type(init_type)}" else: + init_type_name = extract_inner_type(init_type) + init_type_type_adapter_name = TypeName( + f"{init_type_name.value}TypeAdapter" + ) + serdes.append( + _type_adapter_definition( + init_type_type_adapter_name, init_type, module_info + ) + ) render_init_method = f"""\ - lambda x: TypeAdapter({render_type_expr(init_type)}) - .validate_python + lambda x: {render_type_expr(init_type_type_adapter_name)} + .validate_python """ assert init_type is None or render_init_method, ( @@ -889,17 +954,17 @@ def __init__(self, client: river.Client[Any]): procedure.input, RiverConcreteType ) and procedure.input.type in ["array"]: match input_type: - case ListTypeExpr(input_type_name): + case ListTypeExpr(list_type): render_input_method = f"""\ lambda xs: [ - encode_{render_literal_type(input_type_name)}(x) for x in xs + encode_{render_literal_type(list_type)}(x) for x in xs ] """ else: render_input_method = f"encode_{render_literal_type(input_type)}" else: render_input_method = f"""\ - lambda x: TypeAdapter({render_type_expr(input_type)}) + lambda x: {render_type_expr(input_type_type_adapter_name)} .dump_python( x, # type: ignore[arg-type] by_alias=True, diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index dc74fa98..a97fbc9c 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -1,6 +1,6 @@ from typing import Any, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter ERROR_CODE_STREAM_CLOSED = "stream_closed" ERROR_HANDSHAKE = "handshake_failed" @@ -25,6 +25,9 @@ class RiverError(BaseModel): message: str +RiverErrorTypeAdapter = TypeAdapter(RiverError) + + class RiverException(Exception): """Exception raised by the River server.""" diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index fc994615..24545e00 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -5,11 +5,17 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river -from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput +from .rpc_method import ( + Rpc_MethodInput, + Rpc_MethodInputTypeAdapter, + Rpc_MethodOutput, + Rpc_MethodOutputTypeAdapter, + encode_Rpc_MethodInput, +) class Test_ServiceService: @@ -26,10 +32,10 @@ async def rpc_method( "rpc_method", input, encode_Rpc_MethodInput, - lambda x: TypeAdapter(Rpc_MethodOutput).validate_python( + lambda x: Rpc_MethodOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(RiverError).validate_python( + lambda x: RiverErrorTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 263c36b0..d32b4645 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,5 +39,11 @@ class Rpc_MethodInput(TypedDict): data: str +Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput) + + class Rpc_MethodOutput(BaseModel): data: str + + +Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index 7477adb8..ab9eaa08 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -5,20 +5,26 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river from .needsEnum import ( NeedsenumErrors, + NeedsenumErrorsTypeAdapter, NeedsenumInput, + NeedsenumInputTypeAdapter, NeedsenumOutput, + NeedsenumOutputTypeAdapter, encode_NeedsenumInput, ) from .needsEnumObject import ( NeedsenumobjectErrors, + NeedsenumobjectErrorsTypeAdapter, NeedsenumobjectInput, + NeedsenumobjectInputTypeAdapter, NeedsenumobjectOutput, + NeedsenumobjectOutputTypeAdapter, encode_NeedsenumobjectInput, ) @@ -37,10 +43,10 @@ async def needsEnum( "needsEnum", input, lambda x: x, - lambda x: TypeAdapter(NeedsenumOutput).validate_python( + lambda x: NeedsenumOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(NeedsenumErrors).validate_python( + lambda x: NeedsenumErrorsTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, @@ -56,10 +62,10 @@ async def needsEnumObject( "needsEnumObject", input, encode_NeedsenumobjectInput, - lambda x: TypeAdapter(NeedsenumobjectOutput).validate_python( + lambda x: NeedsenumobjectOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(NeedsenumobjectErrors).validate_python( + lambda x: NeedsenumobjectErrorsTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 426049b9..95e6bd2c 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,11 +26,19 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x + +NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput) + NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] + +NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput) + NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] + +NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 6766eb79..4e370433 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -72,6 +72,8 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): else encode_NeedsenumobjectInputOneOf_in_second(x) ) +NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput) + class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): kind: Literal["out_first"] = Field( @@ -103,6 +105,9 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None +NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput) + + class NeedsenumobjectErrorsFooAnyOf_0(RiverError): beep: Optional[Literal["err_first"]] = None @@ -121,3 +126,6 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None + + +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors) diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/stream/generated/test_service/__init__.py index f59c046e..b0d628b2 100644 --- a/tests/codegen/stream/generated/test_service/__init__.py +++ b/tests/codegen/stream/generated/test_service/__init__.py @@ -5,13 +5,15 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river from .stream_method import ( Stream_MethodInput, + Stream_MethodInputTypeAdapter, Stream_MethodOutput, + Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) @@ -31,10 +33,10 @@ async def stream_method( inputStream, None, encode_Stream_MethodInput, - lambda x: TypeAdapter(Stream_MethodOutput).validate_python( + lambda x: Stream_MethodOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(RiverError).validate_python( + lambda x: RiverErrorTypeAdapter.validate_python( x # type: ignore[arg-type] ), ) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 53267467..1294a67a 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,5 +39,11 @@ class Stream_MethodInput(TypedDict): data: str +Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput) + + class Stream_MethodOutput(BaseModel): data: str + + +Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput)