From edde57258aec4e4cf7c04259557465a1dca70025 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:09:47 -0400 Subject: [PATCH 01/16] updates --- src/replit_river/codegen/client.py | 70 ++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index eaa28f55..62fd5418 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -65,6 +65,7 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river """ @@ -761,6 +762,7 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], + ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] class_name = ClassName(f"{schema_name.title()}Service") @@ -798,6 +800,8 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=False, ) + input_type_name = extract_inner_type(input_type) + input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter") serdes.append( ( [extract_inner_type(input_type), *encoder_names], @@ -805,6 +809,13 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) + serdes.append( + ( + [input_type_type_adapter_name], + module_info, + [f"{input_type_type_adapter_name.value} = TypeAdapter({render_type_expr(input_type)}) # type: ignore"] + ) + ) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -812,13 +823,24 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) + output_type_name = extract_inner_type(output_type) serdes.append( ( - [extract_inner_type(output_type), *encoder_names], + [output_type_name, *encoder_names], module_info, output_chunks, ) ) + output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") + # print('appending %r, %r' % (output_type_type_adapter_name, module_info)) + serdes.append( + ( + [output_type_type_adapter_name], + module_info, + [f"{output_type_type_adapter_name.value} = TypeAdapter({render_type_expr(output_type)}) # type: ignore"], + ) + ) + output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( procedure.errors, @@ -827,28 +849,48 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) + # print('error type module_info: %r' % module_info) if isinstance(error_type, NoneTypeExpr): - error_type = TypeName("RiverError") + error_type_name = TypeName("RiverError") + error_type = error_type_name else: + error_type_name = extract_inner_type(error_type) serdes.append( - ([extract_inner_type(error_type)], module_info, errors_chunks) + ([error_type_name], module_info, errors_chunks) ) + + else: - error_type = TypeName("RiverError") - output_or_error_type = UnionTypeExpr([output_type, error_type]) + error_type_name = TypeName("RiverError") + error_type = error_type_name + + error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter") + if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": + # print('error type: %r, %r, %r' % (error_type_type_adapter_name, module_info, output_module_info)) + if len(module_info) == 0: + module_info = output_module_info + serdes.append( + ( + [error_type_type_adapter_name], + module_info, + [f"{error_type_type_adapter_name.value} = TypeAdapter({render_type_expr(error_type)}) # type: ignore"], + ) + ) + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. parse_output_method = f"""\ - lambda x: TypeAdapter({render_type_expr(output_type)}) + lambda x: {output_type_type_adapter_name.value} .validate_python( x # type: ignore[arg-type] ) """ parse_error_method = f"""\ - lambda x: TypeAdapter({render_type_expr(error_type)}) + lambda x: {error_type_type_adapter_name.value} .validate_python( x # type: ignore[arg-type] ) @@ -871,8 +913,17 @@ def __init__(self, client: river.Client[Any]): else: render_init_method = f"encode_{render_literal_type(init_type)}" else: + init_type_name = extract_inner_type(init_type) + init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter") + serdes.append( + ( + [init_type_type_adapter_name], + module_info, + [f"{init_type_type_adapter_name.value} = TypeAdapter({render_type_expr(init_type)}) # type: ignore"] + ) + ) render_init_method = f"""\ - lambda x: TypeAdapter({render_type_expr(init_type)}) + lambda x: {init_type_type_adapter_name.name}) .validate_python """ @@ -898,8 +949,9 @@ def __init__(self, client: river.Client[Any]): else: render_input_method = f"encode_{render_literal_type(input_type)}" else: + render_input_method = f"""\ - lambda x: TypeAdapter({render_type_expr(input_type)}) + lambda x: {input_type_type_adapter_name.value} .dump_python( x, # type: ignore[arg-type] by_alias=True, From f5190e3324baf28948d0a0660c83d1be0ac56c38 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:15:50 -0400 Subject: [PATCH 02/16] refactor --- src/replit_river/codegen/client.py | 43 +++++++++--------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 62fd5418..81f33a1d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -764,6 +764,15 @@ def generate_individual_service( input_base_class: Literal["TypedDict"] | Literal["BaseModel"], ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: + def append_type_adapter_definition(type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName]): + serdes.append( + ( + [type_adapter_name], + module_info, + [f"{type_adapter_name.value} = TypeAdapter({render_type_expr(_type)}) # type: ignore"] + ) + ) + serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] class_name = ClassName(f"{schema_name.title()}Service") current_chunks: List[str] = [ @@ -809,13 +818,7 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) - serdes.append( - ( - [input_type_type_adapter_name], - module_info, - [f"{input_type_type_adapter_name.value} = TypeAdapter({render_type_expr(input_type)}) # type: ignore"] - ) - ) + append_type_adapter_definition(input_type_type_adapter_name, input_type, module_info) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -832,14 +835,7 @@ def __init__(self, client: river.Client[Any]): ) ) output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") - # print('appending %r, %r' % (output_type_type_adapter_name, module_info)) - serdes.append( - ( - [output_type_type_adapter_name], - module_info, - [f"{output_type_type_adapter_name.value} = TypeAdapter({render_type_expr(output_type)}) # type: ignore"], - ) - ) + append_type_adapter_definition(output_type_type_adapter_name, output_type, module_info) output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( @@ -866,16 +862,9 @@ def __init__(self, client: river.Client[Any]): error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter") if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": - # print('error type: %r, %r, %r' % (error_type_type_adapter_name, module_info, output_module_info)) if len(module_info) == 0: module_info = output_module_info - serdes.append( - ( - [error_type_type_adapter_name], - module_info, - [f"{error_type_type_adapter_name.value} = TypeAdapter({render_type_expr(error_type)}) # type: ignore"], - ) - ) + append_type_adapter_definition(error_type_type_adapter_name, error_type, module_info) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) @@ -915,13 +904,7 @@ def __init__(self, client: river.Client[Any]): else: init_type_name = extract_inner_type(init_type) init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter") - serdes.append( - ( - [init_type_type_adapter_name], - module_info, - [f"{init_type_type_adapter_name.value} = TypeAdapter({render_type_expr(init_type)}) # type: ignore"] - ) - ) + append_type_adapter_definition(init_type_type_adapter_name, init_type, module_info) render_init_method = f"""\ lambda x: {init_type_type_adapter_name.name}) .validate_python From 766304eff73260bc8b1533110fd7f85ae87800cc Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:21:49 -0400 Subject: [PATCH 03/16] updates to generated test output --- .../codegen/rpc/generated/test_service/__init__.py | 14 +++++++++++--- .../rpc/generated/test_service/rpc_method.py | 6 ++++++ .../stream/generated/test_service/__init__.py | 8 ++++++-- .../stream/generated/test_service/stream_method.py | 6 ++++++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index fc994615..45ad8f39 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -6,10 +6,18 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError + +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river -from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput +from .rpc_method import ( + Rpc_MethodInput, + Rpc_MethodInputTypeAdapter, + Rpc_MethodOutput, + Rpc_MethodOutputTypeAdapter, + encode_Rpc_MethodInput, +) class Test_ServiceService: @@ -26,10 +34,10 @@ async def rpc_method( "rpc_method", input, encode_Rpc_MethodInput, - lambda x: TypeAdapter(Rpc_MethodOutput).validate_python( + lambda x: Rpc_MethodOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(RiverError).validate_python( + lambda x: RiverErrorTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 263c36b0..4fb7b45c 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,5 +39,11 @@ class Rpc_MethodInput(TypedDict): data: str +Rpc_MethodInputTypeAdapter = TypeAdapter(Rpc_MethodInput) # type: ignore + + class Rpc_MethodOutput(BaseModel): data: str + + +Rpc_MethodOutputTypeAdapter = TypeAdapter(Rpc_MethodOutput) # type: ignore diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/stream/generated/test_service/__init__.py index f59c046e..df410156 100644 --- a/tests/codegen/stream/generated/test_service/__init__.py +++ b/tests/codegen/stream/generated/test_service/__init__.py @@ -6,12 +6,16 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError + +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river from .stream_method import ( Stream_MethodInput, + Stream_MethodInputTypeAdapter, Stream_MethodOutput, + Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) @@ -31,10 +35,10 @@ async def stream_method( inputStream, None, encode_Stream_MethodInput, - lambda x: TypeAdapter(Stream_MethodOutput).validate_python( + lambda x: Stream_MethodOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(RiverError).validate_python( + lambda x: RiverErrorTypeAdapter.validate_python( x # type: ignore[arg-type] ), ) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 53267467..42d7f28c 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,5 +39,11 @@ class Stream_MethodInput(TypedDict): data: str +Stream_MethodInputTypeAdapter = TypeAdapter(Stream_MethodInput) # type: ignore + + class Stream_MethodOutput(BaseModel): data: str + + +Stream_MethodOutputTypeAdapter = TypeAdapter(Stream_MethodOutput) # type: ignore From 9b0f693015fd26ee9bf3bea90a8f702b625d19a1 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:35:42 -0400 Subject: [PATCH 04/16] tests pass --- .python-version | 2 +- .replit | 4 ++-- flake.nix | 2 +- .../test_unknown_enum/enumService/__init__.py | 16 ++++++++++++---- .../test_unknown_enum/enumService/needsEnum.py | 3 +++ .../enumService/needsEnumObject.py | 7 +++++++ 6 files changed, 26 insertions(+), 8 deletions(-) diff --git a/.python-version b/.python-version index 2c073331..e4fba218 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.12 diff --git a/.replit b/.replit index 5b89454c..9a6274c0 100644 --- a/.replit +++ b/.replit @@ -1,6 +1,6 @@ run = "poetry run pytest tests" -modules = ["python-3.11"] +modules = ["python-3.12"] [nix] -channel = "stable-23_11" +channel = "stable-24_11" diff --git a/flake.nix b/flake.nix index 684a5e2c..5ad206f6 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,7 @@ LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib"; }; packages = replitNixDeps ++ [ - pkgs.python311 + pkgs.python312 pkgs.uv ]; }; diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index 7477adb8..a357bdcb 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -6,19 +6,27 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError + +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river from .needsEnum import ( NeedsenumErrors, + NeedsenumErrorsTypeAdapter, NeedsenumInput, + NeedsenumInputTypeAdapter, NeedsenumOutput, + NeedsenumOutputTypeAdapter, encode_NeedsenumInput, ) from .needsEnumObject import ( NeedsenumobjectErrors, + NeedsenumobjectErrorsTypeAdapter, NeedsenumobjectInput, + NeedsenumobjectInputTypeAdapter, NeedsenumobjectOutput, + NeedsenumobjectOutputTypeAdapter, encode_NeedsenumobjectInput, ) @@ -37,10 +45,10 @@ async def needsEnum( "needsEnum", input, lambda x: x, - lambda x: TypeAdapter(NeedsenumOutput).validate_python( + lambda x: NeedsenumOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(NeedsenumErrors).validate_python( + lambda x: NeedsenumErrorsTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, @@ -56,10 +64,10 @@ async def needsEnumObject( "needsEnumObject", input, encode_NeedsenumobjectInput, - lambda x: TypeAdapter(NeedsenumobjectOutput).validate_python( + lambda x: NeedsenumobjectOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(NeedsenumobjectErrors).validate_python( + lambda x: NeedsenumobjectErrorsTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 426049b9..e77ec6d5 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,11 +26,14 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x +NeedsenumInputTypeAdapter = TypeAdapter(NeedsenumInput) # type: ignore NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] +NeedsenumOutputTypeAdapter = TypeAdapter(NeedsenumOutput) # type: ignore NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] +NeedsenumErrorsTypeAdapter = TypeAdapter(NeedsenumErrors) # type: ignore diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 6766eb79..6faf4d76 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -71,6 +71,7 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): if x["kind"] == "in_first" else encode_NeedsenumobjectInputOneOf_in_second(x) ) +NeedsenumobjectInputTypeAdapter = TypeAdapter(NeedsenumobjectInput) # type: ignore class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -103,6 +104,9 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None +NeedsenumobjectOutputTypeAdapter = TypeAdapter(NeedsenumobjectOutput) # type: ignore + + class NeedsenumobjectErrorsFooAnyOf_0(RiverError): beep: Optional[Literal["err_first"]] = None @@ -121,3 +125,6 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None + + +NeedsenumobjectErrorsTypeAdapter = TypeAdapter(NeedsenumobjectErrors) # type: ignore From 29b935334db69b80ad9bf80f95a9c202e6b67f62 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:47:54 -0400 Subject: [PATCH 05/16] lint --- src/replit_river/codegen/client.py | 50 +++++++++++++++++++----------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 81f33a1d..e71fbbf0 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -762,14 +762,23 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], - ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: - def append_type_adapter_definition(type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName]): + def append_type_adapter_definition( + type_adapter_name: TypeName, + _type: TypeExpression, + module_info: list[ModuleName], + ) -> None: serdes.append( ( [type_adapter_name], module_info, - [f"{type_adapter_name.value} = TypeAdapter({render_type_expr(_type)}) # type: ignore"] + [ + FileContents( + f"{type_adapter_name.value} = " + f"TypeAdapter({render_type_expr(_type)})" + " # type: ignore" + ) + ], ) ) @@ -818,7 +827,9 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) - append_type_adapter_definition(input_type_type_adapter_name, input_type, module_info) + append_type_adapter_definition( + input_type_type_adapter_name, input_type, module_info + ) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -835,7 +846,9 @@ def __init__(self, client: river.Client[Any]): ) ) output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") - append_type_adapter_definition(output_type_type_adapter_name, output_type, module_info) + append_type_adapter_definition( + output_type_type_adapter_name, output_type, module_info + ) output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( @@ -851,22 +864,20 @@ def __init__(self, client: river.Client[Any]): error_type = error_type_name else: error_type_name = extract_inner_type(error_type) - serdes.append( - ([error_type_name], module_info, errors_chunks) - ) - + serdes.append(([error_type_name], module_info, errors_chunks)) else: error_type_name = TypeName("RiverError") error_type = error_type_name - error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter") + error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter") if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": if len(module_info) == 0: module_info = output_module_info - append_type_adapter_definition(error_type_type_adapter_name, error_type, module_info) + append_type_adapter_definition( + error_type_type_adapter_name, error_type, module_info + ) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) - # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` @@ -903,10 +914,14 @@ def __init__(self, client: river.Client[Any]): render_init_method = f"encode_{render_literal_type(init_type)}" else: init_type_name = extract_inner_type(init_type) - init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter") - append_type_adapter_definition(init_type_type_adapter_name, init_type, module_info) + init_type_type_adapter_name = TypeName( + f"{init_type_name.value}TypeAdapter" + ) + append_type_adapter_definition( + init_type_type_adapter_name, init_type, module_info + ) render_init_method = f"""\ - lambda x: {init_type_type_adapter_name.name}) + lambda x: {init_type_type_adapter_name.value}) .validate_python """ @@ -923,16 +938,15 @@ def __init__(self, client: river.Client[Any]): procedure.input, RiverConcreteType ) and procedure.input.type in ["array"]: match input_type: - case ListTypeExpr(input_type_name): + case ListTypeExpr(list_type): render_input_method = f"""\ lambda xs: [ - encode_{render_literal_type(input_type_name)}(x) for x in xs + encode_{render_literal_type(list_type)}(x) for x in xs ] """ else: render_input_method = f"encode_{render_literal_type(input_type)}" else: - render_input_method = f"""\ lambda x: {input_type_type_adapter_name.value} .dump_python( From 2f545bf0bb970dd8ba1a4432ca31be1f4caf67e8 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 09:03:19 -0400 Subject: [PATCH 06/16] cleanup --- src/replit_river/codegen/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index e71fbbf0..02c260dc 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -858,7 +858,6 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) - # print('error type module_info: %r' % module_info) if isinstance(error_type, NoneTypeExpr): error_type_name = TypeName("RiverError") error_type = error_type_name From 6b57ec453f221ae8a29028c468365d81a5c7c876 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 12:44:18 -0400 Subject: [PATCH 07/16] use Any --- src/replit_river/codegen/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 02c260dc..3fa007d6 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -768,15 +768,15 @@ def append_type_adapter_definition( _type: TypeExpression, module_info: list[ModuleName], ) -> None: + rendered_type_expr = render_type_expr(_type) serdes.append( ( [type_adapter_name], module_info, [ FileContents( - f"{type_adapter_name.value} = " - f"TypeAdapter({render_type_expr(_type)})" - " # type: ignore" + f"{type_adapter_name.value}: Any = " + f"TypeAdapter({rendered_type_expr})" ) ], ) From 7be8c26bc6d5d409e7c0df98cce81972250ad1d5 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 12:59:19 -0400 Subject: [PATCH 08/16] allow Any for error type so the existing error types can match them; type the type adapters properly --- src/replit_river/client.py | 8 ++++---- src/replit_river/codegen/client.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 51e45a2d..dad15bb7 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -80,7 +80,7 @@ async def send_rpc( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], timeout: timedelta, ) -> ResponseType: with _trace_procedure("rpc", service_name, procedure_name) as span_handle: @@ -105,7 +105,7 @@ async def send_upload( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], ) -> ResponseType: with _trace_procedure("upload", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() @@ -128,7 +128,7 @@ async def send_subscription( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: with _trace_procedure( "subscription", service_name, procedure_name @@ -156,7 +156,7 @@ async def send_stream( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 3fa007d6..9a6518a8 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -775,7 +775,7 @@ def append_type_adapter_definition( module_info, [ FileContents( - f"{type_adapter_name.value}: Any = " + f"{type_adapter_name.value}: TypeAdapter[{rendered_type_expr}] = " f"TypeAdapter({rendered_type_expr})" ) ], From c77679ebf461ea61ff711fe7c68c3ed26eac3c24 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 13:04:44 -0400 Subject: [PATCH 09/16] resnapshotted the tests --- .../codegen/rpc/generated/test_service/rpc_method.py | 6 ++++-- .../test_unknown_enum/enumService/needsEnum.py | 6 +++--- .../test_unknown_enum/enumService/needsEnumObject.py | 12 +++++++++--- .../stream/generated/test_service/stream_method.py | 8 ++++++-- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 4fb7b45c..5bf32657 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,11 +39,13 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter = TypeAdapter(Rpc_MethodInput) # type: ignore +Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) class Rpc_MethodOutput(BaseModel): data: str -Rpc_MethodOutputTypeAdapter = TypeAdapter(Rpc_MethodOutput) # type: ignore +Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( + Rpc_MethodOutput +) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index e77ec6d5..3555eaac 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,14 +26,14 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x -NeedsenumInputTypeAdapter = TypeAdapter(NeedsenumInput) # type: ignore +NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumOutputTypeAdapter = TypeAdapter(NeedsenumOutput) # type: ignore +NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumErrorsTypeAdapter = TypeAdapter(NeedsenumErrors) # type: ignore +NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 6faf4d76..e55e9e46 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -71,7 +71,9 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): if x["kind"] == "in_first" else encode_NeedsenumobjectInputOneOf_in_second(x) ) -NeedsenumobjectInputTypeAdapter = TypeAdapter(NeedsenumobjectInput) # type: ignore +NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( + NeedsenumobjectInput +) class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -104,7 +106,9 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None -NeedsenumobjectOutputTypeAdapter = TypeAdapter(NeedsenumobjectOutput) # type: ignore +NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( + NeedsenumobjectOutput +) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): @@ -127,4 +131,6 @@ class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None -NeedsenumobjectErrorsTypeAdapter = TypeAdapter(NeedsenumobjectErrors) # type: ignore +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( + NeedsenumobjectErrors +) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 42d7f28c..13d45c2d 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,11 +39,15 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter = TypeAdapter(Stream_MethodInput) # type: ignore +Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( + Stream_MethodInput +) class Stream_MethodOutput(BaseModel): data: str -Stream_MethodOutputTypeAdapter = TypeAdapter(Stream_MethodOutput) # type: ignore +Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( + Stream_MethodOutput +) From 8dcd43937133633526146cefffaafe1067bc1661 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 16:16:25 -0400 Subject: [PATCH 10/16] regened the code --- src/replit_river/client.py | 5 +-- src/replit_river/codegen/client.py | 40 +++++++++++-------- src/replit_river/error_schema.py | 5 ++- .../rpc/generated/test_service/__init__.py | 4 +- .../test_unknown_enum/enumService/__init__.py | 4 +- .../stream/generated/test_service/__init__.py | 4 +- 6 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index dad15bb7..5ee79a96 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -21,7 +21,6 @@ ) from .rpc import ( - ErrorType, InitType, RequestType, ResponseType, @@ -129,7 +128,7 @@ async def send_subscription( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], Any], - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure( "subscription", service_name, procedure_name ) as span_handle: @@ -157,7 +156,7 @@ async def send_stream( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], Any], - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() async for msg in session.send_stream( diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 9a6518a8..71bb4a3e 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -64,8 +64,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river """ @@ -763,26 +762,25 @@ def generate_individual_service( schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: + serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] + def append_type_adapter_definition( type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName], ) -> None: rendered_type_expr = render_type_expr(_type) + var_name = render_type_expr(type_adapter_name) + var_type = f"TypeAdapter[{rendered_type_expr}]" + var_value = f"TypeAdapter({rendered_type_expr})" serdes.append( ( [type_adapter_name], module_info, - [ - FileContents( - f"{type_adapter_name.value}: TypeAdapter[{rendered_type_expr}] = " - f"TypeAdapter({rendered_type_expr})" - ) - ], + [FileContents(f"{var_name}: {var_type} = {var_value}")], ) ) - serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] class_name = ClassName(f"{schema_name.title()}Service") current_chunks: List[str] = [ dedent( @@ -819,7 +817,9 @@ def __init__(self, client: river.Client[Any]): permit_unknown_members=False, ) input_type_name = extract_inner_type(input_type) - input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter") + input_type_type_adapter_name = TypeName( + f"{render_literal_type(input_type_name)}TypeAdapter" + ) serdes.append( ( [extract_inner_type(input_type), *encoder_names], @@ -845,7 +845,9 @@ def __init__(self, client: river.Client[Any]): output_chunks, ) ) - output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") + output_type_type_adapter_name = TypeName( + f"{render_literal_type(output_type_name)}TypeAdapter" + ) append_type_adapter_definition( output_type_type_adapter_name, output_type, module_info ) @@ -869,7 +871,9 @@ def __init__(self, client: river.Client[Any]): error_type_name = TypeName("RiverError") error_type = error_type_name - error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter") + error_type_type_adapter_name = TypeName( + f"{render_literal_type(error_type_name)}TypeAdapter" + ) if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": if len(module_info) == 0: module_info = output_module_info @@ -882,14 +886,16 @@ def __init__(self, client: river.Client[Any]): # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. + ottd_name = render_literal_type(output_type_type_adapter_name) parse_output_method = f"""\ - lambda x: {output_type_type_adapter_name.value} + lambda x: {ottd_name} .validate_python( x # type: ignore[arg-type] ) """ + ettd_name = render_literal_type(error_type_type_adapter_name) parse_error_method = f"""\ - lambda x: {error_type_type_adapter_name.value} + lambda x: {ettd_name} .validate_python( x # type: ignore[arg-type] ) @@ -920,8 +926,8 @@ def __init__(self, client: river.Client[Any]): init_type_type_adapter_name, init_type, module_info ) render_init_method = f"""\ - lambda x: {init_type_type_adapter_name.value}) - .validate_python + lambda x: {render_type_expr(init_type_type_adapter_name)} + .validate_python """ assert init_type is None or render_init_method, ( @@ -947,7 +953,7 @@ def __init__(self, client: river.Client[Any]): render_input_method = f"encode_{render_literal_type(input_type)}" else: render_input_method = f"""\ - lambda x: {input_type_type_adapter_name.value} + lambda x: {render_type_expr(input_type_type_adapter_name)} .dump_python( x, # type: ignore[arg-type] by_alias=True, diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index dc74fa98..a97fbc9c 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -1,6 +1,6 @@ from typing import Any, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter ERROR_CODE_STREAM_CLOSED = "stream_closed" ERROR_HANDSHAKE = "handshake_failed" @@ -25,6 +25,9 @@ class RiverError(BaseModel): message: str +RiverErrorTypeAdapter = TypeAdapter(RiverError) + + class RiverException(Exception): """Exception raised by the River server.""" diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index 45ad8f39..24545e00 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -5,9 +5,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError - -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index a357bdcb..ab9eaa08 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -5,9 +5,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError - -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/stream/generated/test_service/__init__.py index df410156..b0d628b2 100644 --- a/tests/codegen/stream/generated/test_service/__init__.py +++ b/tests/codegen/stream/generated/test_service/__init__.py @@ -5,9 +5,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError - -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river From cb3aac9e841901d019bd7a5bb3904d8b0f640269 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:36:04 -0400 Subject: [PATCH 11/16] lint --- src/replit_river/codegen/client.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 71bb4a3e..8a63a4ef 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -770,14 +770,19 @@ def append_type_adapter_definition( module_info: list[ModuleName], ) -> None: rendered_type_expr = render_type_expr(_type) - var_name = render_type_expr(type_adapter_name) - var_type = f"TypeAdapter[{rendered_type_expr}]" - var_value = f"TypeAdapter({rendered_type_expr})" serdes.append( ( [type_adapter_name], module_info, - [FileContents(f"{var_name}: {var_type} = {var_value}")], + [ + FileContents( + dedent(f""" + {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( + TypeAdapter({rendered_type_expr}) + ) + """) + ) + ], ) ) From cda628f88bc70973109591d27368fd049b8337a9 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:37:43 -0400 Subject: [PATCH 12/16] non-abreviated names --- src/replit_river/codegen/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8a63a4ef..fcf93ae1 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -891,16 +891,16 @@ def __init__(self, client: river.Client[Any]): # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. - ottd_name = render_literal_type(output_type_type_adapter_name) + output_type_adapter = render_literal_type(output_type_type_adapter_name) parse_output_method = f"""\ - lambda x: {ottd_name} + lambda x: {output_type_adapter} .validate_python( x # type: ignore[arg-type] ) """ - ettd_name = render_literal_type(error_type_type_adapter_name) + error_type_adapter = render_literal_type(error_type_type_adapter_name) parse_error_method = f"""\ - lambda x: {ettd_name} + lambda x: {error_type_adapter} .validate_python( x # type: ignore[arg-type] ) From 70c5ab8172cba3a48a9b6dc4c1febc4ea26e4894 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:40:54 -0400 Subject: [PATCH 13/16] test snapshot --- .../rpc/generated/test_service/rpc_method.py | 6 ++---- .../test_unknown_enum/enumService/needsEnum.py | 11 ++++++++--- .../enumService/needsEnumObject.py | 13 ++++--------- .../stream/generated/test_service/stream_method.py | 8 ++------ 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 5bf32657..d32b4645 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,13 +39,11 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) +Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput) class Rpc_MethodOutput(BaseModel): data: str -Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( - Rpc_MethodOutput -) +Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 3555eaac..95e6bd2c 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,14 +26,19 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x -NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) + +NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput) + NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) + +NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput) + NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) + +NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index e55e9e46..4e370433 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -71,9 +71,8 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): if x["kind"] == "in_first" else encode_NeedsenumobjectInputOneOf_in_second(x) ) -NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( - NeedsenumobjectInput -) + +NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput) class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -106,9 +105,7 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None -NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( - NeedsenumobjectOutput -) +NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): @@ -131,6 +128,4 @@ class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None -NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( - NeedsenumobjectErrors -) +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 13d45c2d..1294a67a 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,15 +39,11 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( - Stream_MethodInput -) +Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput) class Stream_MethodOutput(BaseModel): data: str -Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( - Stream_MethodOutput -) +Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput) From 217582bf550f29a5bc0a0c54794295afebeba523 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 17:12:49 -0400 Subject: [PATCH 14/16] reverts --- src/replit_river/client.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 5ee79a96..90a695b1 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -21,6 +21,7 @@ ) from .rpc import ( + ErrorType, InitType, RequestType, ResponseType, @@ -79,7 +80,7 @@ async def send_rpc( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], timeout: timedelta, ) -> ResponseType: with _trace_procedure("rpc", service_name, procedure_name) as span_handle: @@ -104,7 +105,7 @@ async def send_upload( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: with _trace_procedure("upload", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() @@ -127,7 +128,7 @@ async def send_subscription( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure( "subscription", service_name, procedure_name @@ -155,7 +156,7 @@ async def send_stream( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() From f25949efa8944e5238600ad1ce065c05f3463413 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 17:14:16 -0400 Subject: [PATCH 15/16] Update src/replit_river/codegen/client.py Co-authored-by: Devon Stewart --- src/replit_river/codegen/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index fcf93ae1..8db12a6d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -768,9 +768,9 @@ def append_type_adapter_definition( type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName], - ) -> None: + ) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]: rendered_type_expr = render_type_expr(_type) - serdes.append( + return ( ( [type_adapter_name], module_info, From 9bd9a47a0cfc1db150850fe189fb827133198be0 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 17:17:13 -0400 Subject: [PATCH 16/16] helper to not append --- src/replit_river/codegen/client.py | 42 +++++++++++++++++------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8db12a6d..7126163d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -764,26 +764,24 @@ def generate_individual_service( ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] - def append_type_adapter_definition( + def _type_adapter_definition( type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName], ) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]: rendered_type_expr = render_type_expr(_type) return ( - ( - [type_adapter_name], - module_info, - [ - FileContents( - dedent(f""" + [type_adapter_name], + module_info, + [ + FileContents( + dedent(f""" {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( TypeAdapter({rendered_type_expr}) ) """) - ) - ], - ) + ) + ], ) class_name = ClassName(f"{schema_name.title()}Service") @@ -832,8 +830,10 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) - append_type_adapter_definition( - input_type_type_adapter_name, input_type, module_info + serdes.append( + _type_adapter_definition( + input_type_type_adapter_name, input_type, module_info + ) ) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, @@ -853,8 +853,10 @@ def __init__(self, client: river.Client[Any]): output_type_type_adapter_name = TypeName( f"{render_literal_type(output_type_name)}TypeAdapter" ) - append_type_adapter_definition( - output_type_type_adapter_name, output_type, module_info + serdes.append( + _type_adapter_definition( + output_type_type_adapter_name, output_type, module_info + ) ) output_module_info = module_info if procedure.errors: @@ -882,8 +884,10 @@ def __init__(self, client: river.Client[Any]): if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": if len(module_info) == 0: module_info = output_module_info - append_type_adapter_definition( - error_type_type_adapter_name, error_type, module_info + serdes.append( + _type_adapter_definition( + error_type_type_adapter_name, error_type, module_info + ) ) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) @@ -927,8 +931,10 @@ def __init__(self, client: river.Client[Any]): init_type_type_adapter_name = TypeName( f"{init_type_name.value}TypeAdapter" ) - append_type_adapter_definition( - init_type_type_adapter_name, init_type, module_info + serdes.append( + _type_adapter_definition( + init_type_type_adapter_name, init_type, module_info + ) ) render_init_method = f"""\ lambda x: {render_type_expr(init_type_type_adapter_name)}