diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 8598c47c..3f852d64 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -13,7 +13,7 @@ ) from replit_river.client_transport import ClientTransport -from replit_river.error_schema import RiverError, RiverException +from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -37,6 +37,10 @@ class RiverUnknownValue(BaseModel): value: Any +class RiverUnknownError(RiverError): + pass + + def translate_unknown_value( value: Any, handler: Callable[[Any], Any], info: ValidationInfo ) -> Any | RiverUnknownValue: @@ -46,6 +50,21 @@ def translate_unknown_value( return RiverUnknownValue(tag="RiverUnknownValue", value=value) +def translate_unknown_error( + value: Any, handler: Callable[[Any], Any], info: ValidationInfo +) -> Any | RiverUnknownError: + try: + return handler(value) + except Exception: + if isinstance(value, dict) and "code" in value and "message" in value: + return RiverUnknownError( + code=value["code"], + message=value["message"], + ) + else: + return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error") + + class Client(Generic[HandshakeMetadataType]): def __init__( self, diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 0c837e12..453b7a53 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -81,7 +81,8 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import RiverUnknownError, translate_unknown_error, \ + RiverUnknownValue, translate_unknown_value import replit_river as river @@ -154,6 +155,20 @@ def encode_type( in_module: list[ModuleName], permit_unknown_members: bool, ) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: + def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr: + if base_model == "RiverError": + return OpenUnionTypeExpr( + UnionTypeExpr(one_of), + fallback_type="RiverUnknownError", + validator_function="translate_unknown_error", + ) + else: + return OpenUnionTypeExpr( + UnionTypeExpr(one_of), + fallback_type="RiverUnknownValue", + validator_function="translate_unknown_value", + ) + encoder_name: TypeName | None = None # defining this up here to placate mypy chunks: list[FileContents] = [] if isinstance(type, RiverNotType): @@ -304,7 +319,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: ) union: TypeExpression if permit_unknown_members: - union = OpenUnionTypeExpr(UnionTypeExpr(one_of)) + union = _make_open_union_type_expr(one_of) else: union = UnionTypeExpr(one_of) chunks.append( @@ -383,7 +398,7 @@ def {_field_name}( ) raise ValueError(f"What does it mean to have {_o2} here?") if permit_unknown_members: - union = OpenUnionTypeExpr(UnionTypeExpr(any_of)) + union = _make_open_union_type_expr(any_of) else: union = UnionTypeExpr(any_of) if is_literal(type): @@ -795,6 +810,7 @@ def _type_adapter_definition( _type: TypeExpression, module_info: list[ModuleName], ) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]: + varname = render_type_expr(type_adapter_name) rendered_type_expr = render_type_expr(_type) return ( [type_adapter_name], @@ -802,10 +818,10 @@ def _type_adapter_definition( [ FileContents( dedent(f""" - {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( - TypeAdapter({rendered_type_expr}) - ) - """) + {varname}: TypeAdapter[{rendered_type_expr}] = ( + TypeAdapter({rendered_type_expr}) + ) + """) ) ], ) diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 0d78005b..53c028ff 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import NewType, assert_never +from typing import NewType, assert_never, cast ModuleName = NewType("ModuleName", str) ClassName = NewType("ClassName", str) @@ -96,6 +96,8 @@ def __lt__(self, other: object) -> bool: @dataclass(frozen=True) class OpenUnionTypeExpr: union: UnionTypeExpr + fallback_type: str + validator_function: str def __str__(self) -> str: raise Exception("Complex type must be put through render_type_expr!") @@ -182,10 +184,11 @@ def render_type_expr(value: TypeExpression) -> str: retval = "None" return retval case OpenUnionTypeExpr(inner): + open_union = cast(OpenUnionTypeExpr, value) return ( "Annotated[" - f"{render_type_expr(inner)} | RiverUnknownValue," - "WrapValidator(translate_unknown_value)" + f"{render_type_expr(inner)} | {open_union.fallback_type}," + f"WrapValidator({open_union.validator_function})" "]" ) case TypeName(name): diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index 6aae861d..af5837dd 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -17,6 +17,9 @@ # ERROR_CODE_CANCEL is the code used when either server or client cancels the stream. ERROR_CODE_CANCEL = "CANCEL" +# ERROR_CODE_UNKNOWN is the code for the RiverUnknownError +ERROR_CODE_UNKNOWN = "UNKNOWN" + class RiverError(BaseModel): """Error message from the server.""" diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index f7dff38d..dfe8a47c 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -35,11 +40,13 @@ class Rpc_MethodInput(TypedDict): data: str -Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput) +Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput) class Rpc_MethodOutput(BaseModel): data: str -Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput) +Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter( + Rpc_MethodOutput +) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py index 4133aea3..44d6c18c 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -16,6 +16,12 @@ Stream_MethodOutputTypeAdapter, encode_Stream_MethodInput, ) +from .emit_error import Emit_ErrorErrors, Emit_ErrorErrorsTypeAdapter + +intTypeAdapter: TypeAdapter[int] = TypeAdapter(int) + + +boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) class Test_ServiceService: @@ -40,3 +46,22 @@ async def stream_method( x # type: ignore[arg-type] ), ) + + async def emit_error( + self, + inputStream: AsyncIterable[int], + ) -> AsyncIterator[bool | Emit_ErrorErrors | RiverError]: + return self.client.send_stream( + "test_service", + "emit_error", + None, + inputStream, + None, + lambda x: x, + lambda x: boolTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: Emit_ErrorErrorsTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + ) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py new file mode 100644 index 00000000..ddba3a38 --- /dev/null +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py @@ -0,0 +1,45 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError): + code: Literal["DATA_LOSS"] + message: str + + +class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError): + code: Literal["UNEXPECTED_DISCONNECT"] + message: str + + +Emit_ErrorErrors = Annotated[ + Emit_ErrorErrorsOneOf_DATA_LOSS + | Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT + | RiverUnknownError, + WrapValidator(translate_unknown_error), +] + + +Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter( + Emit_ErrorErrors +) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index eff77816..5baa9c40 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -35,11 +40,15 @@ class Stream_MethodInput(TypedDict): data: str -Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput) +Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter( + Stream_MethodInput +) class Stream_MethodOutput(BaseModel): data: str -Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput) +Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter( + Stream_MethodOutput +) diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py index bc60d38a..3a578118 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py @@ -31,7 +31,7 @@ encode_Pathological_MethodInputReq_Obj_Undefined, ) -boolTypeAdapter: TypeAdapter[Any] = TypeAdapter(bool) +boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool) class Test_ServiceService: diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index a9b530f5..137add7b 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -473,6 +478,6 @@ class Pathological_MethodInput(TypedDict): undefined: NotRequired[None] -Pathological_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter( - Pathological_MethodInput +Pathological_MethodInputTypeAdapter: TypeAdapter[Pathological_MethodInput] = ( + TypeAdapter(Pathological_MethodInput) ) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 0204f9c2..8f325775 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -24,18 +29,32 @@ def encode_NeedsenumInput(x: "NeedsenumInput") -> Any: return x -NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput) +NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput) NeedsenumOutput = Annotated[ Literal["out_first", "out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] -NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput) +NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput) + + +class NeedsenumErrorsOneOf_err_first(RiverError): + code: Literal["err_first"] + message: str + + +class NeedsenumErrorsOneOf_err_second(RiverError): + code: Literal["err_second"] + message: str + NeedsenumErrors = Annotated[ - Literal["err_first", "err_second"] | RiverUnknownValue, - WrapValidator(translate_unknown_value), + NeedsenumErrorsOneOf_err_first + | NeedsenumErrorsOneOf_err_second + | RiverUnknownError, + WrapValidator(translate_unknown_error), ] -NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors) + +NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors) diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 97559be3..9d74699a 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -12,7 +12,12 @@ from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError -from replit_river.client import RiverUnknownValue, translate_unknown_value +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) import replit_river as river @@ -72,7 +77,9 @@ def encode_NeedsenumobjectInput( ) -NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput) +NeedsenumobjectInputTypeAdapter: TypeAdapter[NeedsenumobjectInput] = TypeAdapter( + NeedsenumobjectInput +) class NeedsenumobjectOutputFooOneOf_out_first(BaseModel): @@ -105,7 +112,9 @@ class NeedsenumobjectOutput(BaseModel): foo: NeedsenumobjectOutputFoo | None = None -NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput) +NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter( + NeedsenumobjectOutput +) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): @@ -119,8 +128,8 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): NeedsenumobjectErrorsFoo = Annotated[ NeedsenumobjectErrorsFooAnyOf_0 | NeedsenumobjectErrorsFooAnyOf_1 - | RiverUnknownValue, - WrapValidator(translate_unknown_value), + | RiverUnknownError, + WrapValidator(translate_unknown_error), ] @@ -128,4 +137,6 @@ class NeedsenumobjectErrors(RiverError): foo: NeedsenumobjectErrorsFoo | None = None -NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors) +NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter( + NeedsenumobjectErrors +) diff --git a/tests/codegen/snapshot/test_enum.py b/tests/codegen/snapshot/test_enum.py index b24fbc20..d9953dee 100644 --- a/tests/codegen/snapshot/test_enum.py +++ b/tests/codegen/snapshot/test_enum.py @@ -1,3 +1,4 @@ +import importlib from io import StringIO from pytest_snapshot.plugin import Snapshot @@ -38,12 +39,30 @@ "errors": { "anyOf": [ { - "type": "string", - "const": "err_first" + "type": "object", + "properties": { + "code": { + "const": "err_first", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] }, { - "type": "string", - "const": "err_second" + "type": "object", + "properties": { + "code": { + "const": "err_second", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] } ] } @@ -157,3 +176,30 @@ def test_unknown_enum(snapshot: Snapshot) -> None: target_path="test_unknown_enum", client_name="foo", ) + + import tests.codegen.snapshot.snapshots.test_unknown_enum + + importlib.reload(tests.codegen.snapshot.snapshots.test_unknown_enum) + from tests.codegen.snapshot.snapshots.test_unknown_enum.enumService.needsEnum import ( # noqa + NeedsenumErrorsTypeAdapter, + ) + + payloads: list[dict[str, str]] = [ + { + "code": "err_first", + "message": "This is a message", + }, + { + "code": "err_second", + "message": "This is a message", + }, + { + "code": "unknown_error", + "message": "This is new!", + }, + ] + + for error in payloads: + x = NeedsenumErrorsTypeAdapter.validate_python(error) + assert x.code == error["code"] + assert x.message == error["message"] diff --git a/tests/codegen/stream/schema.json b/tests/codegen/stream/schema.json index c7aeebcb..d24df12c 100644 --- a/tests/codegen/stream/schema.json +++ b/tests/codegen/stream/schema.json @@ -25,6 +25,45 @@ "not": {} }, "type": "stream" + }, + "emit_error": { + "input": { + "type": "integer" + }, + "output": { + "type": "boolean" + }, + "errors": { + "anyOf": [ + { + "type": "object", + "properties": { + "code": { + "const": "DATA_LOSS", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] + }, + { + "type": "object", + "properties": { + "code": { + "const": "UNEXPECTED_DISCONNECT", + "type": "string" + }, + "message": { + "type": "string" + } + }, + "required": ["code", "message"] + } + ] + }, + "type": "stream" } } } diff --git a/tests/codegen/stream/test_stream.py b/tests/codegen/stream/test_stream.py index 2529fd1a..f3966043 100644 --- a/tests/codegen/stream/test_stream.py +++ b/tests/codegen/stream/test_stream.py @@ -1,26 +1,39 @@ import importlib -from typing import AsyncIterable +from typing import AsyncIterable, Literal import pytest from pytest_snapshot.plugin import Snapshot -from replit_river.client import Client +from replit_river.client import Client, RiverUnknownError from tests.codegen.snapshot.codegen_snapshot_fixtures import validate_codegen -from tests.common_handlers import basic_stream +from tests.common_handlers import basic_stream, error_stream +_AlreadyGenerated = False -@pytest.mark.parametrize("handlers", [{**basic_stream}]) -async def test_basic_stream(snapshot: Snapshot, client: Client) -> None: - validate_codegen( - snapshot=snapshot, - read_schema=lambda: open("tests/codegen/stream/schema.json"), - target_path="test_basic_stream", - client_name="StreamClient", - ) + +@pytest.fixture +def stream_client_codegen(snapshot: Snapshot) -> Literal[True]: + global _AlreadyGenerated + if not _AlreadyGenerated: + validate_codegen( + snapshot=snapshot, + read_schema=lambda: open("tests/codegen/stream/schema.json"), + target_path="test_basic_stream", + client_name="StreamClient", + ) + _AlreadyGenerated = True import tests.codegen.snapshot.snapshots.test_basic_stream importlib.reload(tests.codegen.snapshot.snapshots.test_basic_stream) + return True + + +@pytest.mark.parametrize("handlers", [{**basic_stream}]) +async def test_basic_stream( + stream_client_codegen: Literal[True], + client: Client, +) -> None: from tests.codegen.snapshot.snapshots.test_basic_stream import ( StreamClient, # noqa: E501 ) @@ -41,3 +54,33 @@ async def emit() -> AsyncIterable[Stream_MethodInput]: assert f"Stream response for {i}" == datum.data, f"{i} == {datum.data}" i = i + 1 assert i == 5 + + +@pytest.mark.parametrize("handlers", [{**error_stream}]) +@pytest.mark.parametrize("phase", [0, 1, 2, 3]) +async def test_error_stream( + stream_client_codegen: Literal[True], + erroringClient: Client, + phase: int, +) -> None: + from tests.codegen.snapshot.snapshots.test_basic_stream import ( + StreamClient, # noqa: E501 + ) + + async def emit() -> AsyncIterable[int]: + yield phase + + res = await StreamClient(erroringClient).test_service.emit_error(emit()) + + async for datum in res: + match phase: + case 0: + assert datum + case 1: + assert not datum + case 2: + assert not isinstance(datum, bool) + assert datum.code == "DATA_LOSS" + case 3: + assert isinstance(datum, RiverUnknownError) + assert datum.code == "UNIMPLEMENTED" diff --git a/tests/common_handlers.py b/tests/common_handlers.py index 19a5e2fa..631038f6 100644 --- a/tests/common_handlers.py +++ b/tests/common_handlers.py @@ -80,3 +80,51 @@ async def stream_handler( stream_method_handler(stream_handler, deserialize_request, serialize_response), ), } + + +async def stream_error( + request: Iterator[int] | AsyncIterator[int], + context: grpc.aio.ServicerContext, +) -> AsyncGenerator[bool, None]: + if isinstance(request, AsyncIterator): + async for data in request: + match data % 4: + case 0: + yield True + case 1: + yield False + case 2: + await context.abort( + grpc.StatusCode.DATA_LOSS, + details="We know about the Data Loss error code", + ) + case 3: + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + details="This is a completely unknown error code", + ) + else: + for data in request: + match data % 4: + case 0: + yield True + case 1: + yield False + case 2: + await context.abort( + grpc.StatusCode.DATA_LOSS, + details="We know about the Data Loss error code", + ) + case 3: + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + details="This is a completely unknown error code", + ) + + +error_stream: HandlerMapping = { + ("test_service", "emit_error"): ( + "stream", + stream_method_handler(stream_error, lambda x: x, lambda x: x), + ), +} diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index cf1e1e29..bf576b5c 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -32,7 +32,7 @@ def server( @pytest.fixture -async def client( +async def erroringClient( server: Server, transport_options: TransportOptions, no_logging_error: NoErrors, @@ -69,5 +69,13 @@ async def websocket_uri_factory() -> UriAndMetadata[None]: await server.close() if binding: await binding.wait_closed() - # Server should close normally - no_logging_error() + + +@pytest.fixture +async def client( + erroringClient: Client, + no_logging_error: NoErrors, +) -> AsyncGenerator[Client, None]: + yield erroringClient + # Server should close normally + no_logging_error()