From edde57258aec4e4cf7c04259557465a1dca70025 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:09:47 -0400 Subject: [PATCH 01/29] updates --- src/replit_river/codegen/client.py | 70 ++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index eaa28f55..62fd5418 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -65,6 +65,7 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river """ @@ -761,6 +762,7 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], + ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] class_name = ClassName(f"{schema_name.title()}Service") @@ -798,6 +800,8 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=False, ) + input_type_name = extract_inner_type(input_type) + input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter") serdes.append( ( [extract_inner_type(input_type), *encoder_names], @@ -805,6 +809,13 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) + serdes.append( + ( + [input_type_type_adapter_name], + module_info, + [f"{input_type_type_adapter_name.value} = TypeAdapter({render_type_expr(input_type)}) # type: ignore"] + ) + ) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -812,13 +823,24 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) + output_type_name = extract_inner_type(output_type) serdes.append( ( - [extract_inner_type(output_type), *encoder_names], + [output_type_name, *encoder_names], module_info, output_chunks, ) ) + output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") + # print('appending %r, %r' % (output_type_type_adapter_name, module_info)) + serdes.append( + ( + [output_type_type_adapter_name], + module_info, + [f"{output_type_type_adapter_name.value} = TypeAdapter({render_type_expr(output_type)}) # type: ignore"], + ) + ) + output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( procedure.errors, @@ -827,28 +849,48 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) + # print('error type module_info: %r' % module_info) if isinstance(error_type, NoneTypeExpr): - error_type = TypeName("RiverError") + error_type_name = TypeName("RiverError") + error_type = error_type_name else: + error_type_name = extract_inner_type(error_type) serdes.append( - ([extract_inner_type(error_type)], module_info, errors_chunks) + ([error_type_name], module_info, errors_chunks) ) + + else: - error_type = TypeName("RiverError") - output_or_error_type = UnionTypeExpr([output_type, error_type]) + error_type_name = TypeName("RiverError") + error_type = error_type_name + + error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter") + if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": + # print('error type: %r, %r, %r' % (error_type_type_adapter_name, module_info, output_module_info)) + if len(module_info) == 0: + module_info = output_module_info + serdes.append( + ( + [error_type_type_adapter_name], + module_info, + [f"{error_type_type_adapter_name.value} = TypeAdapter({render_type_expr(error_type)}) # type: ignore"], + ) + ) + output_or_error_type = UnionTypeExpr([output_type, error_type_name]) + # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. parse_output_method = f"""\ - lambda x: TypeAdapter({render_type_expr(output_type)}) + lambda x: {output_type_type_adapter_name.value} .validate_python( x # type: ignore[arg-type] ) """ parse_error_method = f"""\ - lambda x: TypeAdapter({render_type_expr(error_type)}) + lambda x: {error_type_type_adapter_name.value} .validate_python( x # type: ignore[arg-type] ) @@ -871,8 +913,17 @@ def __init__(self, client: river.Client[Any]): else: render_init_method = f"encode_{render_literal_type(init_type)}" else: + init_type_name = extract_inner_type(init_type) + init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter") + serdes.append( + ( + [init_type_type_adapter_name], + module_info, + [f"{init_type_type_adapter_name.value} = TypeAdapter({render_type_expr(init_type)}) # type: ignore"] + ) + ) render_init_method = f"""\ - lambda x: TypeAdapter({render_type_expr(init_type)}) + lambda x: {init_type_type_adapter_name.name}) .validate_python """ @@ -898,8 +949,9 @@ def __init__(self, client: river.Client[Any]): else: render_input_method = f"encode_{render_literal_type(input_type)}" else: + render_input_method = f"""\ - lambda x: TypeAdapter({render_type_expr(input_type)}) + lambda x: {input_type_type_adapter_name.value} .dump_python( x, # type: ignore[arg-type] by_alias=True, From f5190e3324baf28948d0a0660c83d1be0ac56c38 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:15:50 -0400 Subject: [PATCH 02/29] refactor --- src/replit_river/codegen/client.py | 43 +++++++++--------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 62fd5418..81f33a1d 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -764,6 +764,15 @@ def generate_individual_service( input_base_class: Literal["TypedDict"] | Literal["BaseModel"], ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: + def append_type_adapter_definition(type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName]): + serdes.append( + ( + [type_adapter_name], + module_info, + [f"{type_adapter_name.value} = TypeAdapter({render_type_expr(_type)}) # type: ignore"] + ) + ) + serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] class_name = ClassName(f"{schema_name.title()}Service") current_chunks: List[str] = [ @@ -809,13 +818,7 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) - serdes.append( - ( - [input_type_type_adapter_name], - module_info, - [f"{input_type_type_adapter_name.value} = TypeAdapter({render_type_expr(input_type)}) # type: ignore"] - ) - ) + append_type_adapter_definition(input_type_type_adapter_name, input_type, module_info) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -832,14 +835,7 @@ def __init__(self, client: river.Client[Any]): ) ) output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") - # print('appending %r, %r' % (output_type_type_adapter_name, module_info)) - serdes.append( - ( - [output_type_type_adapter_name], - module_info, - [f"{output_type_type_adapter_name.value} = TypeAdapter({render_type_expr(output_type)}) # type: ignore"], - ) - ) + append_type_adapter_definition(output_type_type_adapter_name, output_type, module_info) output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( @@ -866,16 +862,9 @@ def __init__(self, client: river.Client[Any]): error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter") if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": - # print('error type: %r, %r, %r' % (error_type_type_adapter_name, module_info, output_module_info)) if len(module_info) == 0: module_info = output_module_info - serdes.append( - ( - [error_type_type_adapter_name], - module_info, - [f"{error_type_type_adapter_name.value} = TypeAdapter({render_type_expr(error_type)}) # type: ignore"], - ) - ) + append_type_adapter_definition(error_type_type_adapter_name, error_type, module_info) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) @@ -915,13 +904,7 @@ def __init__(self, client: river.Client[Any]): else: init_type_name = extract_inner_type(init_type) init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter") - serdes.append( - ( - [init_type_type_adapter_name], - module_info, - [f"{init_type_type_adapter_name.value} = TypeAdapter({render_type_expr(init_type)}) # type: ignore"] - ) - ) + append_type_adapter_definition(init_type_type_adapter_name, init_type, module_info) render_init_method = f"""\ lambda x: {init_type_type_adapter_name.name}) .validate_python From 766304eff73260bc8b1533110fd7f85ae87800cc Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:21:49 -0400 Subject: [PATCH 03/29] updates to generated test output --- .../codegen/rpc/generated/test_service/__init__.py | 14 +++++++++++--- .../rpc/generated/test_service/rpc_method.py | 6 ++++++ .../stream/generated/test_service/__init__.py | 8 ++++++-- .../stream/generated/test_service/stream_method.py | 6 ++++++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index fc994615..45ad8f39 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -6,10 +6,18 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError + +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river -from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput +from .rpc_method import ( + Rpc_MethodInput, + Rpc_MethodInputTypeAdapter, + Rpc_MethodOutput, + Rpc_MethodOutputTypeAdapter, + encode_Rpc_MethodInput, +) class Test_ServiceService: @@ -26,10 +34,10 @@ async def rpc_method( "rpc_method", input, encode_Rpc_MethodInput, - lambda x: TypeAdapter(Rpc_MethodOutput).validate_python( + lambda x: Rpc_MethodOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(RiverError).validate_python( + lambda x: RiverErrorTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 263c36b0..4fb7b45c 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,5 +39,11 @@ class Rpc_MethodInput(TypedDict): data: str +Rpc_MethodInputTypeAdapter = TypeAdapter(Rpc_MethodInput) # type: ignore + + class Rpc_MethodOutput(BaseModel): data: str + + +Rpc_MethodOutputTypeAdapter = TypeAdapter(Rpc_MethodOutput) # type: ignore diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/stream/generated/test_service/__init__.py index f59c046e..df410156 100644 --- a/tests/codegen/stream/generated/test_service/__init__.py +++ b/tests/codegen/stream/generated/test_service/__init__.py @@ -6,12 +6,16 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError + +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river from .stream_method import ( Stream_MethodInput, + Stream_MethodInputTypeAdapter, Stream_MethodOutput, + Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) @@ -31,10 +35,10 @@ async def stream_method( inputStream, None, encode_Stream_MethodInput, - lambda x: TypeAdapter(Stream_MethodOutput).validate_python( + lambda x: Stream_MethodOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(RiverError).validate_python( + lambda x: RiverErrorTypeAdapter.validate_python( x # type: ignore[arg-type] ), ) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 53267467..42d7f28c 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,5 +39,11 @@ class Stream_MethodInput(TypedDict): data: str +Stream_MethodInputTypeAdapter = TypeAdapter(Stream_MethodInput) # type: ignore + + class Stream_MethodOutput(BaseModel): data: str + + +Stream_MethodOutputTypeAdapter = TypeAdapter(Stream_MethodOutput) # type: ignore From 9b0f693015fd26ee9bf3bea90a8f702b625d19a1 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:35:42 -0400 Subject: [PATCH 04/29] tests pass --- .python-version | 2 +- .replit | 4 ++-- flake.nix | 2 +- .../test_unknown_enum/enumService/__init__.py | 16 ++++++++++++---- .../test_unknown_enum/enumService/needsEnum.py | 3 +++ .../enumService/needsEnumObject.py | 7 +++++++ 6 files changed, 26 insertions(+), 8 deletions(-) diff --git a/.python-version b/.python-version index 2c073331..e4fba218 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.12 diff --git a/.replit b/.replit index 5b89454c..9a6274c0 100644 --- a/.replit +++ b/.replit @@ -1,6 +1,6 @@ run = "poetry run pytest tests" -modules = ["python-3.11"] +modules = ["python-3.12"] [nix] -channel = "stable-23_11" +channel = "stable-24_11" diff --git a/flake.nix b/flake.nix index 684a5e2c..5ad206f6 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,7 @@ LD_LIBRARY_PATH = "${pkgs.stdenv.cc.cc.lib}/lib"; }; packages = replitNixDeps ++ [ - pkgs.python311 + pkgs.python312 pkgs.uv ]; }; diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index 7477adb8..a357bdcb 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -6,19 +6,27 @@ from pydantic import TypeAdapter from replit_river.error_schema import RiverError + +RiverErrorTypeAdapter = TypeAdapter(RiverError) import replit_river as river from .needsEnum import ( NeedsenumErrors, + NeedsenumErrorsTypeAdapter, NeedsenumInput, + NeedsenumInputTypeAdapter, NeedsenumOutput, + NeedsenumOutputTypeAdapter, encode_NeedsenumInput, ) from .needsEnumObject import ( NeedsenumobjectErrors, + NeedsenumobjectErrorsTypeAdapter, NeedsenumobjectInput, + NeedsenumobjectInputTypeAdapter, NeedsenumobjectOutput, + NeedsenumobjectOutputTypeAdapter, encode_NeedsenumobjectInput, ) @@ -37,10 +45,10 @@ async def needsEnum( "needsEnum", input, lambda x: x, - lambda x: TypeAdapter(NeedsenumOutput).validate_python( + lambda x: NeedsenumOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(NeedsenumErrors).validate_python( + lambda x: NeedsenumErrorsTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, @@ -56,10 +64,10 @@ async def needsEnumObject( "needsEnumObject", input, encode_NeedsenumobjectInput, - lambda x: TypeAdapter(NeedsenumobjectOutput).validate_python( + lambda x: NeedsenumobjectOutputTypeAdapter.validate_python( x # type: ignore[arg-type] ), - lambda x: TypeAdapter(NeedsenumobjectErrors).validate_python( + lambda x: NeedsenumobjectErrorsTypeAdapter.validate_python( x # type: ignore[arg-type] ), timeout, diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 426049b9..e77ec6d5 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,11 +26,14 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x +NeedsenumInputTypeAdapter = TypeAdapter(NeedsenumInput) # type: ignore NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] +NeedsenumOutputTypeAdapter = TypeAdapter(NeedsenumOutput) # type: ignore NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] +NeedsenumErrorsTypeAdapter = TypeAdapter(NeedsenumErrors) # type: ignore diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 6766eb79..6faf4d76 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -71,6 +71,7 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): if x["kind"] == "in_first" else encode_NeedsenumobjectInputOneOf_in_second(x) ) +NeedsenumobjectInputTypeAdapter = TypeAdapter(NeedsenumobjectInput) # type: ignore class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -103,6 +104,9 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None +NeedsenumobjectOutputTypeAdapter = TypeAdapter(NeedsenumobjectOutput) # type: ignore + + class NeedsenumobjectErrorsFooAnyOf_0(RiverError): beep: Optional[Literal["err_first"]] = None @@ -121,3 +125,6 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None + + +NeedsenumobjectErrorsTypeAdapter = TypeAdapter(NeedsenumobjectErrors) # type: ignore From 29b935334db69b80ad9bf80f95a9c202e6b67f62 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Wed, 12 Mar 2025 20:47:54 -0400 Subject: [PATCH 05/29] lint --- src/replit_river/codegen/client.py | 50 +++++++++++++++++++----------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 81f33a1d..e71fbbf0 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -762,14 +762,23 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], - ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: - def append_type_adapter_definition(type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName]): + def append_type_adapter_definition( + type_adapter_name: TypeName, + _type: TypeExpression, + module_info: list[ModuleName], + ) -> None: serdes.append( ( [type_adapter_name], module_info, - [f"{type_adapter_name.value} = TypeAdapter({render_type_expr(_type)}) # type: ignore"] + [ + FileContents( + f"{type_adapter_name.value} = " + f"TypeAdapter({render_type_expr(_type)})" + " # type: ignore" + ) + ], ) ) @@ -818,7 +827,9 @@ def __init__(self, client: river.Client[Any]): input_chunks, ) ) - append_type_adapter_definition(input_type_type_adapter_name, input_type, module_info) + append_type_adapter_definition( + input_type_type_adapter_name, input_type, module_info + ) output_type, module_info, output_chunks, encoder_names = encode_type( procedure.output, TypeName(f"{name.title()}Output"), @@ -835,7 +846,9 @@ def __init__(self, client: river.Client[Any]): ) ) output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") - append_type_adapter_definition(output_type_type_adapter_name, output_type, module_info) + append_type_adapter_definition( + output_type_type_adapter_name, output_type, module_info + ) output_module_info = module_info if procedure.errors: error_type, module_info, errors_chunks, encoder_names = encode_type( @@ -851,22 +864,20 @@ def __init__(self, client: river.Client[Any]): error_type = error_type_name else: error_type_name = extract_inner_type(error_type) - serdes.append( - ([error_type_name], module_info, errors_chunks) - ) - + serdes.append(([error_type_name], module_info, errors_chunks)) else: error_type_name = TypeName("RiverError") error_type = error_type_name - error_type_type_adapter_name = TypeName(f"{error_type.value}TypeAdapter") + error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter") if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": if len(module_info) == 0: module_info = output_module_info - append_type_adapter_definition(error_type_type_adapter_name, error_type, module_info) + append_type_adapter_definition( + error_type_type_adapter_name, error_type, module_info + ) output_or_error_type = UnionTypeExpr([output_type, error_type_name]) - # NB: These strings must be indented to at least the same level of # the function strings in the branches below, otherwise `dedent` @@ -903,10 +914,14 @@ def __init__(self, client: river.Client[Any]): render_init_method = f"encode_{render_literal_type(init_type)}" else: init_type_name = extract_inner_type(init_type) - init_type_type_adapter_name = TypeName(f"{init_type_name.value}TypeAdapter") - append_type_adapter_definition(init_type_type_adapter_name, init_type, module_info) + init_type_type_adapter_name = TypeName( + f"{init_type_name.value}TypeAdapter" + ) + append_type_adapter_definition( + init_type_type_adapter_name, init_type, module_info + ) render_init_method = f"""\ - lambda x: {init_type_type_adapter_name.name}) + lambda x: {init_type_type_adapter_name.value}) .validate_python """ @@ -923,16 +938,15 @@ def __init__(self, client: river.Client[Any]): procedure.input, RiverConcreteType ) and procedure.input.type in ["array"]: match input_type: - case ListTypeExpr(input_type_name): + case ListTypeExpr(list_type): render_input_method = f"""\ lambda xs: [ - encode_{render_literal_type(input_type_name)}(x) for x in xs + encode_{render_literal_type(list_type)}(x) for x in xs ] """ else: render_input_method = f"encode_{render_literal_type(input_type)}" else: - render_input_method = f"""\ lambda x: {input_type_type_adapter_name.value} .dump_python( From 2f545bf0bb970dd8ba1a4432ca31be1f4caf67e8 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 09:03:19 -0400 Subject: [PATCH 06/29] cleanup --- src/replit_river/codegen/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index e71fbbf0..02c260dc 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -858,7 +858,6 @@ def __init__(self, client: river.Client[Any]): module_names, permit_unknown_members=True, ) - # print('error type module_info: %r' % module_info) if isinstance(error_type, NoneTypeExpr): error_type_name = TypeName("RiverError") error_type = error_type_name From 6b57ec453f221ae8a29028c468365d81a5c7c876 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 12:44:18 -0400 Subject: [PATCH 07/29] use Any --- src/replit_river/codegen/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 02c260dc..3fa007d6 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -768,15 +768,15 @@ def append_type_adapter_definition( _type: TypeExpression, module_info: list[ModuleName], ) -> None: + rendered_type_expr = render_type_expr(_type) serdes.append( ( [type_adapter_name], module_info, [ FileContents( - f"{type_adapter_name.value} = " - f"TypeAdapter({render_type_expr(_type)})" - " # type: ignore" + f"{type_adapter_name.value}: Any = " + f"TypeAdapter({rendered_type_expr})" ) ], ) From 7be8c26bc6d5d409e7c0df98cce81972250ad1d5 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 12:59:19 -0400 Subject: [PATCH 08/29] allow Any for error type so the existing error types can match them; type the type adapters properly --- src/replit_river/client.py | 8 ++++---- src/replit_river/codegen/client.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 51e45a2d..dad15bb7 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -80,7 +80,7 @@ async def send_rpc( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], timeout: timedelta, ) -> ResponseType: with _trace_procedure("rpc", service_name, procedure_name) as span_handle: @@ -105,7 +105,7 @@ async def send_upload( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], ) -> ResponseType: with _trace_procedure("upload", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() @@ -128,7 +128,7 @@ async def send_subscription( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: with _trace_procedure( "subscription", service_name, procedure_name @@ -156,7 +156,7 @@ async def send_stream( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], ErrorType], + error_deserializer: Callable[[Any], Any], ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 3fa007d6..9a6518a8 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -775,7 +775,7 @@ def append_type_adapter_definition( module_info, [ FileContents( - f"{type_adapter_name.value}: Any = " + f"{type_adapter_name.value}: TypeAdapter[{rendered_type_expr}] = " f"TypeAdapter({rendered_type_expr})" ) ], From c77679ebf461ea61ff711fe7c68c3ed26eac3c24 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 13:04:44 -0400 Subject: [PATCH 09/29] resnapshotted the tests --- .../codegen/rpc/generated/test_service/rpc_method.py | 6 ++++-- .../test_unknown_enum/enumService/needsEnum.py | 6 +++--- .../test_unknown_enum/enumService/needsEnumObject.py | 12 +++++++++--- .../stream/generated/test_service/stream_method.py | 8 ++++++-- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 4fb7b45c..5bf32657 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,11 +39,13 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter = TypeAdapter(Rpc_MethodInput) # type: ignore +Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) class Rpc_MethodOutput(BaseModel): data: str -Rpc_MethodOutputTypeAdapter = TypeAdapter(Rpc_MethodOutput) # type: ignore +Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( + Rpc_MethodOutput +) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index e77ec6d5..3555eaac 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,14 +26,14 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x -NeedsenumInputTypeAdapter = TypeAdapter(NeedsenumInput) # type: ignore +NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumOutputTypeAdapter = TypeAdapter(NeedsenumOutput) # type: ignore +NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumErrorsTypeAdapter = TypeAdapter(NeedsenumErrors) # type: ignore +NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 6faf4d76..e55e9e46 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -71,7 +71,9 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): if x["kind"] == "in_first" else encode_NeedsenumobjectInputOneOf_in_second(x) ) -NeedsenumobjectInputTypeAdapter = TypeAdapter(NeedsenumobjectInput) # type: ignore +NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( + NeedsenumobjectInput +) class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -104,7 +106,9 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None -NeedsenumobjectOutputTypeAdapter = TypeAdapter(NeedsenumobjectOutput) # type: ignore +NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( + NeedsenumobjectOutput +) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): @@ -127,4 +131,6 @@ class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None -NeedsenumobjectErrorsTypeAdapter = TypeAdapter(NeedsenumobjectErrors) # type: ignore +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( + NeedsenumobjectErrors +) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 42d7f28c..13d45c2d 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,11 +39,15 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter = TypeAdapter(Stream_MethodInput) # type: ignore +Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( + Stream_MethodInput +) class Stream_MethodOutput(BaseModel): data: str -Stream_MethodOutputTypeAdapter = TypeAdapter(Stream_MethodOutput) # type: ignore +Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( + Stream_MethodOutput +) From 8dcd43937133633526146cefffaafe1067bc1661 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Thu, 13 Mar 2025 16:16:25 -0400 Subject: [PATCH 10/29] regened the code --- src/replit_river/client.py | 5 +-- src/replit_river/codegen/client.py | 40 +++++++++++-------- src/replit_river/error_schema.py | 5 ++- .../rpc/generated/test_service/__init__.py | 4 +- .../test_unknown_enum/enumService/__init__.py | 4 +- .../stream/generated/test_service/__init__.py | 4 +- 6 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index dad15bb7..5ee79a96 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -21,7 +21,6 @@ ) from .rpc import ( - ErrorType, InitType, RequestType, ResponseType, @@ -129,7 +128,7 @@ async def send_subscription( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], Any], - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure( "subscription", service_name, procedure_name ) as span_handle: @@ -157,7 +156,7 @@ async def send_stream( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], Any], - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() async for msg in session.send_stream( diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 9a6518a8..71bb4a3e 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -64,8 +64,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river """ @@ -763,26 +762,25 @@ def generate_individual_service( schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], ) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: + serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] + def append_type_adapter_definition( type_adapter_name: TypeName, _type: TypeExpression, module_info: list[ModuleName], ) -> None: rendered_type_expr = render_type_expr(_type) + var_name = render_type_expr(type_adapter_name) + var_type = f"TypeAdapter[{rendered_type_expr}]" + var_value = f"TypeAdapter({rendered_type_expr})" serdes.append( ( [type_adapter_name], module_info, - [ - FileContents( - f"{type_adapter_name.value}: TypeAdapter[{rendered_type_expr}] = " - f"TypeAdapter({rendered_type_expr})" - ) - ], + [FileContents(f"{var_name}: {var_type} = {var_value}")], ) ) - serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] class_name = ClassName(f"{schema_name.title()}Service") current_chunks: List[str] = [ dedent( @@ -819,7 +817,9 @@ def __init__(self, client: river.Client[Any]): permit_unknown_members=False, ) input_type_name = extract_inner_type(input_type) - input_type_type_adapter_name = TypeName(f"{input_type_name.value}TypeAdapter") + input_type_type_adapter_name = TypeName( + f"{render_literal_type(input_type_name)}TypeAdapter" + ) serdes.append( ( [extract_inner_type(input_type), *encoder_names], @@ -845,7 +845,9 @@ def __init__(self, client: river.Client[Any]): output_chunks, ) ) - output_type_type_adapter_name = TypeName(f"{output_type_name.value}TypeAdapter") + output_type_type_adapter_name = TypeName( + f"{render_literal_type(output_type_name)}TypeAdapter" + ) append_type_adapter_definition( output_type_type_adapter_name, output_type, module_info ) @@ -869,7 +871,9 @@ def __init__(self, client: river.Client[Any]): error_type_name = TypeName("RiverError") error_type = error_type_name - error_type_type_adapter_name = TypeName(f"{error_type_name.value}TypeAdapter") + error_type_type_adapter_name = TypeName( + f"{render_literal_type(error_type_name)}TypeAdapter" + ) if error_type_type_adapter_name.value != "RiverErrorTypeAdapter": if len(module_info) == 0: module_info = output_module_info @@ -882,14 +886,16 @@ def __init__(self, client: river.Client[Any]): # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. + ottd_name = render_literal_type(output_type_type_adapter_name) parse_output_method = f"""\ - lambda x: {output_type_type_adapter_name.value} + lambda x: {ottd_name} .validate_python( x # type: ignore[arg-type] ) """ + ettd_name = render_literal_type(error_type_type_adapter_name) parse_error_method = f"""\ - lambda x: {error_type_type_adapter_name.value} + lambda x: {ettd_name} .validate_python( x # type: ignore[arg-type] ) @@ -920,8 +926,8 @@ def __init__(self, client: river.Client[Any]): init_type_type_adapter_name, init_type, module_info ) render_init_method = f"""\ - lambda x: {init_type_type_adapter_name.value}) - .validate_python + lambda x: {render_type_expr(init_type_type_adapter_name)} + .validate_python """ assert init_type is None or render_init_method, ( @@ -947,7 +953,7 @@ def __init__(self, client: river.Client[Any]): render_input_method = f"encode_{render_literal_type(input_type)}" else: render_input_method = f"""\ - lambda x: {input_type_type_adapter_name.value} + lambda x: {render_type_expr(input_type_type_adapter_name)} .dump_python( x, # type: ignore[arg-type] by_alias=True, diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index dc74fa98..a97fbc9c 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -1,6 +1,6 @@ from typing import Any, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter ERROR_CODE_STREAM_CLOSED = "stream_closed" ERROR_HANDSHAKE = "handshake_failed" @@ -25,6 +25,9 @@ class RiverError(BaseModel): message: str +RiverErrorTypeAdapter = TypeAdapter(RiverError) + + class RiverException(Exception): """Exception raised by the River server.""" diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index 45ad8f39..24545e00 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -5,9 +5,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError - -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py index a357bdcb..ab9eaa08 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/__init__.py @@ -5,9 +5,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError - -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/stream/generated/test_service/__init__.py index df410156..b0d628b2 100644 --- a/tests/codegen/stream/generated/test_service/__init__.py +++ b/tests/codegen/stream/generated/test_service/__init__.py @@ -5,9 +5,7 @@ from pydantic import TypeAdapter -from replit_river.error_schema import RiverError - -RiverErrorTypeAdapter = TypeAdapter(RiverError) +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter import replit_river as river From cb3aac9e841901d019bd7a5bb3904d8b0f640269 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:36:04 -0400 Subject: [PATCH 11/29] lint --- src/replit_river/codegen/client.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 71bb4a3e..8a63a4ef 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -770,14 +770,19 @@ def append_type_adapter_definition( module_info: list[ModuleName], ) -> None: rendered_type_expr = render_type_expr(_type) - var_name = render_type_expr(type_adapter_name) - var_type = f"TypeAdapter[{rendered_type_expr}]" - var_value = f"TypeAdapter({rendered_type_expr})" serdes.append( ( [type_adapter_name], module_info, - [FileContents(f"{var_name}: {var_type} = {var_value}")], + [ + FileContents( + dedent(f""" + {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( + TypeAdapter({rendered_type_expr}) + ) + """) + ) + ], ) ) From cda628f88bc70973109591d27368fd049b8337a9 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:37:43 -0400 Subject: [PATCH 12/29] non-abreviated names --- src/replit_river/codegen/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8a63a4ef..fcf93ae1 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -891,16 +891,16 @@ def __init__(self, client: river.Client[Any]): # the function strings in the branches below, otherwise `dedent` # will pick our indentation level for normalization, which will # break the "def" indentation presuppositions. - ottd_name = render_literal_type(output_type_type_adapter_name) + output_type_adapter = render_literal_type(output_type_type_adapter_name) parse_output_method = f"""\ - lambda x: {ottd_name} + lambda x: {output_type_adapter} .validate_python( x # type: ignore[arg-type] ) """ - ettd_name = render_literal_type(error_type_type_adapter_name) + error_type_adapter = render_literal_type(error_type_type_adapter_name) parse_error_method = f"""\ - lambda x: {ettd_name} + lambda x: {error_type_adapter} .validate_python( x # type: ignore[arg-type] ) From 70c5ab8172cba3a48a9b6dc4c1febc4ea26e4894 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:40:54 -0400 Subject: [PATCH 13/29] test snapshot --- .../rpc/generated/test_service/rpc_method.py | 6 ++---- .../test_unknown_enum/enumService/needsEnum.py | 11 ++++++++--- .../enumService/needsEnumObject.py | 13 ++++--------- .../stream/generated/test_service/stream_method.py | 8 ++------ 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 5bf32657..d32b4645 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -39,13 +39,11 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) +Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput) class Rpc_MethodOutput(BaseModel): data: str -Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( - Rpc_MethodOutput -) +Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 3555eaac..95e6bd2c 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -26,14 +26,19 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x -NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) + +NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput) + NeedsenumOutput = Annotated[ Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) + +NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput) + NeedsenumErrors = Annotated[ Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) + +NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index e55e9e46..4e370433 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -71,9 +71,8 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): if x["kind"] == "in_first" else encode_NeedsenumobjectInputOneOf_in_second(x) ) -NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( - NeedsenumobjectInput -) + +NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput) class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -106,9 +105,7 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None -NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( - NeedsenumobjectOutput -) +NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): @@ -131,6 +128,4 @@ class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None -NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( - NeedsenumobjectErrors -) +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 13d45c2d..1294a67a 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -39,15 +39,11 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( - Stream_MethodInput -) +Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput) class Stream_MethodOutput(BaseModel): data: str -Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( - Stream_MethodOutput -) +Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput) From 1e3aa579318d5891e246d77c210eb2168926df90 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:55:19 -0400 Subject: [PATCH 14/29] add UnknownRiverError and translate_unknown_error --- src/replit_river/client.py | 10 +++++++++ src/replit_river/codegen/client.py | 5 +++-- src/replit_river/codegen/typing.py | 4 ++-- .../rpc/generated/test_service/rpc_method.py | 8 ++++--- .../enumService/needsEnum.py | 16 +++++++------- .../enumService/needsEnumObject.py | 22 ++++++++++++------- .../generated/test_service/stream_method.py | 10 ++++++--- 7 files changed, 49 insertions(+), 26 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 5ee79a96..4e0e06dc 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -35,6 +35,9 @@ class RiverUnknownValue(BaseModel): tag: Literal["RiverUnknownValue"] value: Any +@dataclass(frozen=True) +class RiverUnknownError(RiverError): + pass def translate_unknown_value( value: Any, handler: Callable[[Any], Any], info: ValidationInfo @@ -44,6 +47,13 @@ def translate_unknown_value( except Exception: return RiverUnknownValue(tag="RiverUnknownValue", value=value) +def translate_unknown_error( + value: Any, handler: Callable[[Any], Any], info: ValidationInfo +) -> Any | RiverUnknownError: + try: + return handler(value) + except Exception: + return RiverUnknownError() class Client(Generic[HandshakeMetadataType]): def __init__( diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index fcf93ae1..576437dc 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -93,7 +93,7 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import RiverUnknownError, translate_unknown_error import replit_river as river @@ -769,6 +769,7 @@ def append_type_adapter_definition( _type: TypeExpression, module_info: list[ModuleName], ) -> None: + varname = render_type_expr(type_adapter_name) rendered_type_expr = render_type_expr(_type) serdes.append( ( @@ -777,7 +778,7 @@ def append_type_adapter_definition( [ FileContents( dedent(f""" - {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( + {varname}: TypeAdapter[{rendered_type_expr}] = ( TypeAdapter({rendered_type_expr}) ) """) diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index ea1f00b7..6c512851 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -87,8 +87,8 @@ def render_type_expr(value: TypeExpression) -> str: case OpenUnionTypeExpr(inner): return ( "Annotated[" - f"{render_type_expr(inner)} | RiverUnknownValue," - "WrapValidator(translate_unknown_value)" + f"{render_type_expr(inner)} | RiverUnknownError," + "WrapValidator(translate_unknown_error)" "]" ) case TypeName(name): diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index d32b4645..a88eab29 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -19,7 +19,7 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import RiverUnknownError, translate_unknown_error import replit_river as river @@ -39,11 +39,13 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput) +Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) class Rpc_MethodOutput(BaseModel): data: str -Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput) +Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( + Rpc_MethodOutput +) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 95e6bd2c..afd893f4 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -19,7 +19,7 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import RiverUnknownError, translate_unknown_error import replit_river as river @@ -27,18 +27,18 @@ NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x -NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput) +NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) NeedsenumOutput = Annotated[ - Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, - WrapValidator(translate_unknown_value), + Literal["out_first"] | Literal["out_second"] | RiverUnknownError, + WrapValidator(translate_unknown_error), ] -NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput) +NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) NeedsenumErrors = Annotated[ - Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, - WrapValidator(translate_unknown_value), + Literal["err_first"] | Literal["err_second"] | RiverUnknownError, + WrapValidator(translate_unknown_error), ] -NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors) +NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 4e370433..a0a54b6d 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -19,7 +19,7 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import RiverUnknownError, translate_unknown_error import replit_river as river @@ -72,7 +72,9 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): else encode_NeedsenumobjectInputOneOf_in_second(x) ) -NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput) +NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( + NeedsenumobjectInput +) class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -96,8 +98,8 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel): NeedsenumobjectOutputFoo = Annotated[ NeedsenumobjectOutputFooOneOf_out_first | NeedsenumobjectOutputFooOneOf_out_second - | RiverUnknownValue, - WrapValidator(translate_unknown_value), + | RiverUnknownError, + WrapValidator(translate_unknown_error), ] @@ -105,7 +107,9 @@ class NeedsenumobjectOutput(BaseModel): foo: Optional[NeedsenumobjectOutputFoo] = None -NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput) +NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( + NeedsenumobjectOutput +) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): @@ -119,8 +123,8 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): NeedsenumobjectErrorsFoo = Annotated[ NeedsenumobjectErrorsFooAnyOf_0 | NeedsenumobjectErrorsFooAnyOf_1 - | RiverUnknownValue, - WrapValidator(translate_unknown_value), + | RiverUnknownError, + WrapValidator(translate_unknown_error), ] @@ -128,4 +132,6 @@ class NeedsenumobjectErrors(RiverError): foo: Optional[NeedsenumobjectErrorsFoo] = None -NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors) +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( + NeedsenumobjectErrors +) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index 1294a67a..a2516e29 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -19,7 +19,7 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import RiverUnknownError, translate_unknown_error import replit_river as river @@ -39,11 +39,15 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput) +Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( + Stream_MethodInput +) class Stream_MethodOutput(BaseModel): data: str -Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput) +Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( + Stream_MethodOutput +) From 693b174bfa1c74354d65d395105825ea84f29c4a Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 16:55:50 -0400 Subject: [PATCH 15/29] lint --- src/replit_river/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 4e0e06dc..8c2ffbff 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -35,10 +35,12 @@ class RiverUnknownValue(BaseModel): tag: Literal["RiverUnknownValue"] value: Any + @dataclass(frozen=True) class RiverUnknownError(RiverError): pass + def translate_unknown_value( value: Any, handler: Callable[[Any], Any], info: ValidationInfo ) -> Any | RiverUnknownValue: @@ -47,6 +49,7 @@ def translate_unknown_value( except Exception: return RiverUnknownValue(tag="RiverUnknownValue", value=value) + def translate_unknown_error( value: Any, handler: Callable[[Any], Any], info: ValidationInfo ) -> Any | RiverUnknownError: @@ -55,6 +58,7 @@ def translate_unknown_error( except Exception: return RiverUnknownError() + class Client(Generic[HandshakeMetadataType]): def __init__( self, From 7203fef6dc1f12ab5267c0afdad4a40a2bec3a2a Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 17:01:42 -0400 Subject: [PATCH 16/29] fixes --- src/replit_river/client.py | 5 ++--- src/replit_river/error_schema.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 8c2ffbff..aedc056a 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -13,7 +13,7 @@ ) from replit_river.client_transport import ClientTransport -from replit_river.error_schema import RiverError, RiverException +from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -36,7 +36,6 @@ class RiverUnknownValue(BaseModel): value: Any -@dataclass(frozen=True) class RiverUnknownError(RiverError): pass @@ -56,7 +55,7 @@ def translate_unknown_error( try: return handler(value) except Exception: - return RiverUnknownError() + return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") class Client(Generic[HandshakeMetadataType]): diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index a97fbc9c..a627a702 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -17,6 +17,9 @@ # ERROR_CODE_CANCEL is the code used when either server or client cancels the stream. ERROR_CODE_CANCEL = "CANCEL" +# ERROR_CODE_UNKNOWN is the code for the RiverUnknownError +ERROR_CODE_UNKNOWN = "UNKNOWN" + class RiverError(BaseModel): """Error message from the server.""" From 26f386774503bcf0d48b035bafd30ce60b1436e0 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 17:11:56 -0400 Subject: [PATCH 17/29] reverts --- src/replit_river/client.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index aedc056a..51c969bc 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -21,6 +21,7 @@ ) from .rpc import ( + ErrorType, InitType, RequestType, ResponseType, @@ -92,7 +93,7 @@ async def send_rpc( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], timeout: timedelta, ) -> ResponseType: with _trace_procedure("rpc", service_name, procedure_name) as span_handle: @@ -117,7 +118,7 @@ async def send_upload( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], ) -> ResponseType: with _trace_procedure("upload", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() @@ -140,7 +141,7 @@ async def send_subscription( request: RequestType, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure( "subscription", service_name, procedure_name @@ -168,7 +169,7 @@ async def send_stream( init_serializer: Optional[Callable[[InitType], Any]], request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], - error_deserializer: Callable[[Any], Any], + error_deserializer: Callable[[Any], ErrorType], ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() From 86b36a1fdb904f3619ca6973d91653c5bc0f734f Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 17:45:17 -0400 Subject: [PATCH 18/29] fixed test --- .../enumService/needsEnum.py | 16 +++++++++++- tests/codegen/snapshot/test_enum.py | 26 ++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index afd893f4..8e4636fb 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -36,9 +36,23 @@ NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) + +class NeedsenumErrorsOneOf_err_first(RiverError): + code: Literal["err_first"] + message: str + + +class NeedsenumErrorsOneOf_err_second(RiverError): + code: Literal["err_second"] + message: str + + NeedsenumErrors = Annotated[ - Literal["err_first"] | Literal["err_second"] | RiverUnknownError, + NeedsenumErrorsOneOf_err_first + | NeedsenumErrorsOneOf_err_second + | RiverUnknownError, WrapValidator(translate_unknown_error), ] + NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/test_enum.py b/tests/codegen/snapshot/test_enum.py index e3ccfc0d..e2c0dc29 100644 --- a/tests/codegen/snapshot/test_enum.py +++ b/tests/codegen/snapshot/test_enum.py @@ -40,12 +40,30 @@ "errors": { "anyOf": [ { - "type": "string", - "const": "err_first" + "type": "object", + "properties": { + "code": { + "const": "err_first", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] }, { - "type": "string", - "const": "err_second" + "type": "object", + "properties": { + "code": { + "const": "err_second", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] } ] } From c46e64afb4c5da4366bd5c275a354a769e61e91d Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 14 Mar 2025 20:21:42 -0400 Subject: [PATCH 19/29] Update src/replit_river/client.py Co-authored-by: Devon Stewart --- src/replit_river/client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 51c969bc..3f986ca1 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -56,7 +56,13 @@ def translate_unknown_error( try: return handler(value) except Exception: - return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") + if isinstance(value, dict) and "code" in value and "message" in value: + return RiverUnknownError( + code=value["code"], + message=value["message"], + ) + else: + return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") class Client(Generic[HandshakeMetadataType]): From 672492fa5be6cd5ae26f2559b48641a29741c361 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 21 Mar 2025 10:01:36 -0400 Subject: [PATCH 20/29] make error type and translate function dynamic --- src/replit_river/codegen/client.py | 18 +++++++++++++++--- src/replit_river/codegen/typing.py | 6 ++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 16cbd915..311647f8 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -93,7 +93,8 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownError, translate_unknown_error +from replit_river.client import RiverUnknownError, translate_unknown_error, \ + RiverUnknownValue, translate_unknown_value import replit_river as river @@ -168,6 +169,17 @@ def encode_type( in_module: list[ModuleName], permit_unknown_members: bool, ) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: + def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr: + return OpenUnionTypeExpr( + UnionTypeExpr(one_of), + fallback_type="RiverUnknownError" + if base_model == "RiverError" + else "RiverUnknownValue", + validator_function="translate_unknown_error" + if base_model == "RiverError" + else "translate_unknown_value", + ) + encoder_name: TypeName | None = None # defining this up here to placate mypy chunks: List[FileContents] = [] if isinstance(type, RiverNotType): @@ -318,7 +330,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: ) union: TypeExpression if permit_unknown_members: - union = OpenUnionTypeExpr(UnionTypeExpr(one_of)) + union = _make_open_union_type_expr(one_of) else: union = UnionTypeExpr(one_of) chunks.append( @@ -392,7 +404,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: ) raise ValueError(f"What does it mean to have {_o2} here?") if permit_unknown_members: - union = OpenUnionTypeExpr(UnionTypeExpr(any_of)) + union = _make_open_union_type_expr(any_of) else: union = UnionTypeExpr(any_of) if is_literal(type): diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 6c512851..0c72445f 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -58,6 +58,8 @@ def __str__(self) -> str: @dataclass(frozen=True) class OpenUnionTypeExpr: union: UnionTypeExpr + fallback_type: str + validator_function: str def __str__(self) -> str: raise Exception("Complex type must be put through render_type_expr!") @@ -87,8 +89,8 @@ def render_type_expr(value: TypeExpression) -> str: case OpenUnionTypeExpr(inner): return ( "Annotated[" - f"{render_type_expr(inner)} | RiverUnknownError," - "WrapValidator(translate_unknown_error)" + f"{render_type_expr(inner)} | {value.fallback_type}," + f"WrapValidator({value.validator_function})" "]" ) case TypeName(name): From 3f959e7d2921dbc18e86dbcf20ed82054a63fd39 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 21 Mar 2025 10:05:41 -0400 Subject: [PATCH 21/29] updated snapshots --- .../rpc/generated/test_service/rpc_method.py | 7 ++++++- .../test_basic_stream/test_service/stream_method.py | 7 ++++++- .../test_service/__init__.py | 2 +- .../test_service/pathological_method.py | 11 ++++++++--- .../test_unknown_enum/enumService/needsEnum.py | 13 ++++++++++--- .../enumService/needsEnumObject.py | 11 ++++++++--- 6 files changed, 39 insertions(+), 12 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index c377dc0d..dfe8a47c 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownError, translate_unknown_error +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 8d59e9ad..5baa9c40 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownError, translate_unknown_error +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py index bc60d38a..3a578118 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py @@ -31,7 +31,7 @@ encode_Pathological_MethodInputReq_Obj_Undefined, ) -boolTypeAdapter: TypeAdapter[Any] = TypeAdapter(bool) +boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) class Test_ServiceService: diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index a9b530f5..137add7b 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -473,6 +478,6 @@ class Pathological_MethodInput(TypedDict): undefined: NotRequired[None] -Pathological_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter( - Pathological_MethodInput +Pathological_MethodInputTypeAdapter: TypeAdapter[Pathological_MethodInput] = ( + TypeAdapter(Pathological_MethodInput) ) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 42edb8f2..8f325775 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownError, translate_unknown_error +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -45,8 +50,10 @@ class NeedsenumErrorsOneOf_err_second(RiverError): NeedsenumErrors = Annotated[ - Literal["err_first", "err_second"] | RiverUnknownValue, - WrapValidator(translate_unknown_value), + NeedsenumErrorsOneOf_err_first + | NeedsenumErrorsOneOf_err_second + | RiverUnknownError, + WrapValidator(translate_unknown_error), ] diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 90147c1f..9d74699a 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownError, translate_unknown_error +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -98,8 +103,8 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel): NeedsenumobjectOutputFoo = Annotated[ NeedsenumobjectOutputFooOneOf_out_first | NeedsenumobjectOutputFooOneOf_out_second - | RiverUnknownError, - WrapValidator(translate_unknown_error), + | RiverUnknownValue, + WrapValidator(translate_unknown_value), ] From b16f1814af6fa8e1f72229970282ef284138bb66 Mon Sep 17 00:00:00 2001 From: Toby Ho Date: Fri, 21 Mar 2025 10:16:26 -0400 Subject: [PATCH 22/29] lint --- src/replit_river/codegen/client.py | 8 ++++---- src/replit_river/codegen/typing.py | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index af86cccd..5dbc5725 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -159,11 +159,11 @@ def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExp return OpenUnionTypeExpr( UnionTypeExpr(one_of), fallback_type="RiverUnknownError" - if base_model == "RiverError" - else "RiverUnknownValue", + if base_model == "RiverError" + else "RiverUnknownValue", validator_function="translate_unknown_error" - if base_model == "RiverError" - else "translate_unknown_value", + if base_model == "RiverError" + else "translate_unknown_value", ) encoder_name: TypeName | None = None # defining this up here to placate mypy diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 92afbb2c..53c028ff 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import NewType, assert_never +from typing import NewType, assert_never, cast ModuleName = NewType("ModuleName", str) ClassName = NewType("ClassName", str) @@ -184,10 +184,11 @@ def render_type_expr(value: TypeExpression) -> str: retval = "None" return retval case OpenUnionTypeExpr(inner): + open_union = cast(OpenUnionTypeExpr, value) return ( "Annotated[" - f"{render_type_expr(inner)} | {value.fallback_type}," - f"WrapValidator({value.validator_function})" + f"{render_type_expr(inner)} | {open_union.fallback_type}," + f"WrapValidator({open_union.validator_function})" "]" ) case TypeName(name): From ada5db6824dc6ca244ba5c85e3e8029f93b352e0 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 17:57:25 -0700 Subject: [PATCH 23/29] Make the code more readable --- src/replit_river/codegen/client.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 5dbc5725..453b7a53 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -156,15 +156,18 @@ def encode_type( permit_unknown_members: bool, ) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr: - return OpenUnionTypeExpr( - UnionTypeExpr(one_of), - fallback_type="RiverUnknownError" - if base_model == "RiverError" - else "RiverUnknownValue", - validator_function="translate_unknown_error" - if base_model == "RiverError" - else "translate_unknown_value", - ) + if base_model == "RiverError": + return OpenUnionTypeExpr( + UnionTypeExpr(one_of), + fallback_type="RiverUnknownError", + validator_function="translate_unknown_error", + ) + else: + return OpenUnionTypeExpr( + UnionTypeExpr(one_of), + fallback_type="RiverUnknownValue", + validator_function="translate_unknown_value", + ) encoder_name: TypeName | None = None # defining this up here to placate mypy chunks: list[FileContents] = [] From e6edc73bb7d7149d3bb6acf83f8d87bb131015ab Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 18:19:16 -0700 Subject: [PATCH 24/29] Test for unknown error values --- tests/codegen/snapshot/test_enum.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/codegen/snapshot/test_enum.py b/tests/codegen/snapshot/test_enum.py index 976a62d9..d9953dee 100644 --- a/tests/codegen/snapshot/test_enum.py +++ b/tests/codegen/snapshot/test_enum.py @@ -1,3 +1,4 @@ +import importlib from io import StringIO from pytest_snapshot.plugin import Snapshot @@ -175,3 +176,30 @@ def test_unknown_enum(snapshot: Snapshot) -> None: target_path="test_unknown_enum", client_name="foo", ) + + import tests.codegen.snapshot.snapshots.test_unknown_enum + + importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) + from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnum import ( # noqa + NeedsenumErrorsTypeAdapter, + ) + + payloads: list[dict[str, str]] = [ + { + "code": "err_first", + "message": "This is a message", + }, + { + "code": "err_second", + "message": "This is a message", + }, + { + "code": "unknown_error", + "message": "This is new!", + }, + ] + + for error in payloads: + x = NeedsenumErrorsTypeAdapter.validate_python(error) + assert x.code == error["code"] + assert x.message == error["message"] From ca6c0fe2b97aa3b93c59cf4df5560f5e5461148a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 18:38:39 -0700 Subject: [PATCH 25/29] Make it possible to request a client that can emit errors --- tests/river_fixtures/clientserver.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index cf1e1e29..bf576b5c 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -32,7 +32,7 @@ def server( @pytest.fixture -async def client( +async def erroringClient( server: Server, transport_options: TransportOptions, no_logging_error: NoErrors, @@ -69,5 +69,13 @@ async def websocket_uri_factory() -> UriAndMetadata[None]: await server.close() if binding: await binding.wait_closed() - # Server should close normally - no_logging_error() + + +@pytest.fixture +async def client( + erroringClient: Client, + no_logging_error: NoErrors, +) -> AsyncGenerator[Client, None]: + yield erroringClient + # Server should close normally + no_logging_error() From a0db2fed50cfd07fb3d12cc57ef4f8984a9e91df Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 19:56:23 -0700 Subject: [PATCH 26/29] Adding a streaming method that can emit known and unknown errors --- tests/codegen/stream/schema.json | 39 ++++++++++++++++++++++++++ tests/common_handlers.py | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/tests/codegen/stream/schema.json b/tests/codegen/stream/schema.json index c7aeebcb..d24df12c 100644 --- a/tests/codegen/stream/schema.json +++ b/tests/codegen/stream/schema.json @@ -25,6 +25,45 @@ "not": {} }, "type": "stream" + }, + "emit_error": { + "input": { + "type": "integer" + }, + "output": { + "type": "boolean" + }, + "errors": { + "anyOf": [ + { + "type": "object", + "properties": { + "code": { + "const": "DATA_LOSS", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] + }, + { + "type": "object", + "properties": { + "code": { + "const": "UNEXPECTED_DISCONNECT", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] + } + ] + }, + "type": "stream" } } } diff --git a/tests/common_handlers.py b/tests/common_handlers.py index 19a5e2fa..631038f6 100644 --- a/tests/common_handlers.py +++ b/tests/common_handlers.py @@ -80,3 +80,51 @@ async def stream_handler( stream_method_handler(stream_handler, deserialize_request, serialize_response), ), } + + +async def stream_error( + request: Iterator[int] | AsyncIterator[int], + context: grpc.aio.ServicerContext, +) -> AsyncGenerator[bool, None]: + if isinstance(request, AsyncIterator): + async for data in request: + match data % 4: + case 0: + yield True + case 1: + yield False + case 2: + await context.abort( + grpc.StatusCode.DATA_LOSS, + details="We know about the Data Loss error code", + ) + case 3: + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + details="This is a completely unknown error code", + ) + else: + for data in request: + match data % 4: + case 0: + yield True + case 1: + yield False + case 2: + await context.abort( + grpc.StatusCode.DATA_LOSS, + details="We know about the Data Loss error code", + ) + case 3: + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + details="This is a completely unknown error code", + ) + + +error_stream: HandlerMapping = { + ("test_service", "emit_error"): ( + "stream", + stream_method_handler(stream_error, lambda x: x, lambda x: x), + ), +} From 75f3431aecdf6d93af427aa6034a68e63aa28265 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 19:56:43 -0700 Subject: [PATCH 27/29] Regenerating code --- .../test_service/__init__.py | 25 +++++++++++ .../test_service/emit_error.py | 45 +++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py index 4133aea3..44d6c18c 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -16,6 +16,12 @@ Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) +from .emit_error import Emit_ErrorErrors, Emit_ErrorErrorsTypeAdapter + +intTypeAdapter: TypeAdapter[int] = TypeAdapter(int) + + +boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) class Test_ServiceService: @@ -40,3 +46,22 @@ async def stream_method( x # type: ignore[arg-type] ), ) + + async def emit_error( + self, + inputStream: AsyncIterable[int], + ) -> AsyncIterator[bool | Emit_ErrorErrors | RiverError]: + return self.client.send_stream( + "test_service", + "emit_error", + None, + inputStream, + None, + lambda x: x, + lambda x: boolTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: Emit_ErrorErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py new file mode 100644 index 00000000..ddba3a38 --- /dev/null +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -0,0 +1,45 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError): + code: Literal["DATA_LOSS"] + message: str + + +class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError): + code: Literal["UNEXPECTED_DISCONNECT"] + message: str + + +Emit_ErrorErrors = Annotated[ + Emit_ErrorErrorsOneOf_DATA_LOSS + | Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT + | RiverUnknownError, + WrapValidator(translate_unknown_error), +] + + +Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter( + Emit_ErrorErrors +) From 7f98b0c99cadfca47d51a0255e41baa4b3efccb8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 19:57:57 -0700 Subject: [PATCH 28/29] Only generate the snapshot code once --- tests/codegen/stream/test_stream.py | 31 ++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/codegen/stream/test_stream.py b/tests/codegen/stream/test_stream.py index 2529fd1a..a4ec5f1e 100644 --- a/tests/codegen/stream/test_stream.py +++ b/tests/codegen/stream/test_stream.py @@ -1,5 +1,5 @@ import importlib -from typing import AsyncIterable +from typing import AsyncIterable, Literal import pytest from pytest_snapshot.plugin import Snapshot @@ -8,19 +8,32 @@ from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen from tests.common_handlers import basic_stream +_AlreadyGenerated = False -@pytest.mark.parametrize("handlers", [{**basic_stream}]) -async def test_basic_stream(snapshot: Snapshot, client: Client) -> None: - validate_codegen( - snapshot=snapshot, - read_schema=lambda: open("tests/codegen/stream/schema.json"), - target_path="test_basic_stream", - client_name="StreamClient", - ) + +@pytest.fixture +def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: + global _AlreadyGenerated + if not _AlreadyGenerated: + validate_codegen( + snapshot=snapshot, + read_schema=lambda: open("tests/codegen/stream/schema.json"), + target_path="test_basic_stream", + client_name="StreamClient", + ) + _AlreadyGenerated = True import tests.codegen.snapshot.snapshots.test_basic_stream importlib.reload(tests.codegen.snapshot.snapshots.test_basic_stream) + return True + + +@pytest.mark.parametrize("handlers", [{**basic_stream}]) +async def test_basic_stream( + stream_client_codegen: Literal[True], + client: Client, +) -> None: from tests.codegen.snapshot.snapshots.test_basic_stream import ( StreamClient, # noqa: E501 ) From 8f5fce43e80592f9d2326b03a5c9778bc38f7a7e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 21 Mar 2025 20:07:32 -0700 Subject: [PATCH 29/29] Actually write a test for stream errors --- tests/codegen/stream/test_stream.py | 34 +++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/codegen/stream/test_stream.py b/tests/codegen/stream/test_stream.py index a4ec5f1e..f3966043 100644 --- a/tests/codegen/stream/test_stream.py +++ b/tests/codegen/stream/test_stream.py @@ -4,9 +4,9 @@ import pytest from pytest_snapshot.plugin import Snapshot -from replit_river.client import Client +from replit_river.client import Client, RiverUnknownError from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen -from tests.common_handlers import basic_stream +from tests.common_handlers import basic_stream, error_stream _AlreadyGenerated = False @@ -54,3 +54,33 @@ async def emit() -> AsyncIterable[Stream_MethodInput]: assert f"Stream response for {i}" == datum.data, f"{i} == {datum.data}" i = i + 1 assert i == 5 + + +@pytest.mark.parametrize("handlers", [{**error_stream}]) +@pytest.mark.parametrize("phase", [0, 1, 2, 3]) +async def test_error_stream( + stream_client_codegen: Literal[True], + erroringClient: Client, + phase: int, +) -> None: + from tests.codegen.snapshot.snapshots.test_basic_stream import ( + StreamClient, # noqa: E501 + ) + + async def emit() -> AsyncIterable[int]: + yield phase + + res = await StreamClient(erroringClient).test_service.emit_error(emit()) + + async for datum in res: + match phase: + case 0: + assert datum + case 1: + assert not datum + case 2: + assert not isinstance(datum, bool) + assert datum.code == "DATA_LOSS" + case 3: + assert isinstance(datum, RiverUnknownError) + assert datum.code == "UNIMPLEMENTED"