diff --git a/src/replit_river/client.py b/src/replit_river/client.py index c78fb171..51e45a2d 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -7,6 +7,10 @@ from opentelemetry import trace from opentelemetry.trace import Span, SpanKind, Status, StatusCode +from pydantic import ( + BaseModel, + ValidationInfo, +) from replit_river.client_transport import ClientTransport from replit_river.error_schema import RiverError, RiverException @@ -27,6 +31,21 @@ tracer = trace.get_tracer(__name__) +@dataclass(frozen=True) +class RiverUnknownValue(BaseModel): + tag: Literal["RiverUnknownValue"] + value: Any + + +def translate_unknown_value( + value: Any, handler: Callable[[Any], Any], info: ValidationInfo +) -> Any | RiverUnknownValue: + try: + return handler(value) + except Exception: + return RiverUnknownValue(tag="RiverUnknownValue", value=value) + + class Client(Generic[HandshakeMetadataType]): def __init__( self, diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index c58f745c..f7271dc7 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -30,11 +30,11 @@ ListTypeExpr, LiteralTypeExpr, ModuleName, + OpenUnionTypeExpr, RenderedPath, TypeExpression, TypeName, UnionTypeExpr, - UnknownTypeExpr, ensure_literal_type, extract_inner_type, render_type_expr, @@ -83,15 +83,16 @@ Literal, Optional, Mapping, - NewType, NotRequired, Union, Tuple, TypedDict, ) +from typing_extensions import Annotated -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError +from replit_river.client import RiverUnknownValue, translate_unknown_value import replit_river as river @@ -311,19 +312,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: else """, ) + union: TypeExpression if permit_unknown_members: - unknown_name = TypeName(f"{prefix}AnyOf__Unknown") - chunks.append( - FileContents( - f"{unknown_name} = NewType({repr(unknown_name)}, object)" - ) - ) - one_of.append(UnknownTypeExpr(unknown_name)) - chunks.append( - FileContents( - f"{prefix} = {render_type_expr(UnionTypeExpr(one_of))}" - ) - ) + union = OpenUnionTypeExpr(UnionTypeExpr(one_of)) + else: + union = UnionTypeExpr(one_of) + chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}")) chunks.append(FileContents("")) if base_model == "TypedDict": @@ -386,16 +380,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: f"encode_{ensure_literal_type(other)}(x)" ) if permit_unknown_members: - unknown_name = TypeName(f"{prefix}AnyOf__Unknown") - chunks.append( - FileContents(f"{unknown_name} = NewType({repr(unknown_name)}, object)") - ) - any_of.append(UnknownTypeExpr(unknown_name)) + union = OpenUnionTypeExpr(UnionTypeExpr(any_of)) + else: + union = UnionTypeExpr(any_of) if is_literal(type): typeddict_encoder = ["x"] - chunks.append( - FileContents(f"{prefix} = {render_type_expr(UnionTypeExpr(any_of))}") - ) + chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}")) if base_model == "TypedDict": encoder_name = TypeName(f"encode_{prefix}") encoder_names.add(encoder_name) diff --git a/src/replit_river/codegen/run.py b/src/replit_river/codegen/run.py index dd041483..c9ab6384 100644 --- a/src/replit_river/codegen/run.py +++ b/src/replit_river/codegen/run.py @@ -30,7 +30,7 @@ def main() -> None: client = subparsers.add_parser( "client", help="Codegen a River client from JSON schema" ) - client.add_argument("--output", help="output file", required=True) + client.add_argument("--output", help="output path", required=True) client.add_argument("--client-name", help="name of the class", required=True) client.add_argument( "--typed-dict-inputs", diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index af9ae7fa..87c577bf 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -31,8 +31,8 @@ class UnionTypeExpr: @dataclass -class UnknownTypeExpr: - name: TypeName +class OpenUnionTypeExpr: + union: UnionTypeExpr TypeExpression = ( @@ -41,7 +41,7 @@ class UnknownTypeExpr: | ListTypeExpr | LiteralTypeExpr | UnionTypeExpr - | UnknownTypeExpr + | OpenUnionTypeExpr ) @@ -55,10 +55,15 @@ def render_type_expr(value: TypeExpression) -> str: return f"Literal[{repr(inner)}]" case UnionTypeExpr(inner): return " | ".join(render_type_expr(x) for x in inner) + case OpenUnionTypeExpr(inner): + return ( + "Annotated[" + f"{render_type_expr(inner)} | RiverUnknownValue," + "WrapValidator(translate_unknown_value)" + "]" + ) case str(name): return TypeName(name) - case UnknownTypeExpr(name): - return TypeName(name) case other: assert_never(other) @@ -75,10 +80,12 @@ def extract_inner_type(value: TypeExpression) -> TypeName: raise ValueError( f"Attempting to extract from a union, currently not possible: {value}" ) + case OpenUnionTypeExpr(_): + raise ValueError( + f"Attempting to extract from a union, currently not possible: {value}" + ) case str(name): return TypeName(name) - case UnknownTypeExpr(name): - return name case other: assert_never(other) @@ -101,9 +108,11 @@ def ensure_literal_type(value: TypeExpression) -> TypeName: raise ValueError( f"Unexpected expression when expecting a type name: {value}" ) + case OpenUnionTypeExpr(_): + raise ValueError( + f"Unexpected expression when expecting a type name: {value}" + ) case str(name): return TypeName(name) - case UnknownTypeExpr(name): - return name case other: assert_never(other) diff --git a/tests/codegen/rpc/generated/test_service/__init__.py b/tests/codegen/rpc/generated/test_service/__init__.py index 8012ee6b..fc994615 100644 --- a/tests/codegen/rpc/generated/test_service/__init__.py +++ b/tests/codegen/rpc/generated/test_service/__init__.py @@ -9,7 +9,7 @@ import replit_river as river -from .rpc_method import encode_Rpc_MethodInput, Rpc_MethodInput, Rpc_MethodOutput +from .rpc_method import Rpc_MethodInput, Rpc_MethodOutput, encode_Rpc_MethodInput class Test_ServiceService: diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index d6dc64f4..263c36b0 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -10,15 +10,16 @@ Literal, Optional, Mapping, - NewType, NotRequired, Union, Tuple, TypedDict, ) +from typing_extensions import Annotated -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError +from replit_river.client import RiverUnknownValue, translate_unknown_value import replit_river as river diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 4633dfd6..426049b9 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -10,26 +10,27 @@ Literal, Optional, Mapping, - NewType, NotRequired, Union, Tuple, TypedDict, ) +from typing_extensions import Annotated -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError +from replit_river.client import RiverUnknownValue, translate_unknown_value import replit_river as river NeedsenumInput = Literal["in_first"] | Literal["in_second"] encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x -NeedsenumOutputAnyOf__Unknown = NewType("NeedsenumOutputAnyOf__Unknown", object) -NeedsenumOutput = ( - Literal["out_first"] | Literal["out_second"] | NeedsenumOutputAnyOf__Unknown -) -NeedsenumErrorsAnyOf__Unknown = NewType("NeedsenumErrorsAnyOf__Unknown", object) -NeedsenumErrors = ( - Literal["err_first"] | Literal["err_second"] | NeedsenumErrorsAnyOf__Unknown -) +NeedsenumOutput = Annotated[ + Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, + WrapValidator(translate_unknown_value), +] +NeedsenumErrors = Annotated[ + Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, + WrapValidator(translate_unknown_value), +] diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 2fe76987..6766eb79 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -10,15 +10,16 @@ Literal, Optional, Mapping, - NewType, NotRequired, Union, Tuple, TypedDict, ) +from typing_extensions import Annotated -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError +from replit_river.client import RiverUnknownValue, translate_unknown_value import replit_river as river @@ -90,14 +91,12 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel): bar: int -NeedsenumobjectOutputFooAnyOf__Unknown = NewType( - "NeedsenumobjectOutputFooAnyOf__Unknown", object -) -NeedsenumobjectOutputFoo = ( +NeedsenumobjectOutputFoo = Annotated[ NeedsenumobjectOutputFooOneOf_out_first | NeedsenumobjectOutputFooOneOf_out_second - | NeedsenumobjectOutputFooAnyOf__Unknown -) + | RiverUnknownValue, + WrapValidator(translate_unknown_value), +] class NeedsenumobjectOutput(BaseModel): @@ -112,14 +111,12 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): borp: Optional[Literal["err_second"]] = None -NeedsenumobjectErrorsFooAnyOf__Unknown = NewType( - "NeedsenumobjectErrorsFooAnyOf__Unknown", object -) -NeedsenumobjectErrorsFoo = ( +NeedsenumobjectErrorsFoo = Annotated[ NeedsenumobjectErrorsFooAnyOf_0 | NeedsenumobjectErrorsFooAnyOf_1 - | NeedsenumobjectErrorsFooAnyOf__Unknown -) + | RiverUnknownValue, + WrapValidator(translate_unknown_value), +] class NeedsenumobjectErrors(RiverError): diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/stream/generated/test_service/__init__.py index 00f9c553..f59c046e 100644 --- a/tests/codegen/stream/generated/test_service/__init__.py +++ b/tests/codegen/stream/generated/test_service/__init__.py @@ -10,9 +10,9 @@ from .stream_method import ( - encode_Stream_MethodInput, - Stream_MethodOutput, Stream_MethodInput, + Stream_MethodOutput, + encode_Stream_MethodInput, ) diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/stream/generated/test_service/stream_method.py index d66aff55..53267467 100644 --- a/tests/codegen/stream/generated/test_service/stream_method.py +++ b/tests/codegen/stream/generated/test_service/stream_method.py @@ -10,15 +10,16 @@ Literal, Optional, Mapping, - NewType, NotRequired, Union, Tuple, TypedDict, ) +from typing_extensions import Annotated -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator from replit_river.error_schema import RiverError +from replit_river.client import RiverUnknownValue, translate_unknown_value import replit_river as river