From d177703f25c67affdc0995e856b64b6d5f28cbcb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 17 Mar 2025 16:10:21 -0700 Subject: [PATCH 01/14] Moving generated code --- .../snapshots/test_basic_stream}/__init__.py | 0 .../snapshots/test_basic_stream}/test_service/__init__.py | 0 .../snapshots/test_basic_stream}/test_service/stream_method.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/codegen/{stream/generated => snapshot/snapshots/test_basic_stream}/__init__.py (100%) rename tests/codegen/{stream/generated => snapshot/snapshots/test_basic_stream}/test_service/__init__.py (100%) rename tests/codegen/{stream/generated => snapshot/snapshots/test_basic_stream}/test_service/stream_method.py (100%) diff --git a/tests/codegen/stream/generated/__init__.py b/tests/codegen/snapshot/snapshots/test_basic_stream/__init__.py similarity index 100% rename from tests/codegen/stream/generated/__init__.py rename to tests/codegen/snapshot/snapshots/test_basic_stream/__init__.py diff --git a/tests/codegen/stream/generated/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py similarity index 100% rename from tests/codegen/stream/generated/test_service/__init__.py rename to tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py diff --git a/tests/codegen/stream/generated/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py similarity index 100% rename from tests/codegen/stream/generated/test_service/stream_method.py rename to tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py From 5cd0efc179479d3ee394ac1f412c1643aef70af5 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 17 Mar 2025 16:21:47 -0700 Subject: [PATCH 02/14] Switching janky implementation of pytest-snapshot over to just use pytest-snapshot --- tests/codegen/stream/test_stream.py | 49 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/codegen/stream/test_stream.py b/tests/codegen/stream/test_stream.py index b093a297..c03940b4 100644 --- a/tests/codegen/stream/test_stream.py +++ b/tests/codegen/stream/test_stream.py @@ -1,42 +1,49 @@ import importlib -import shutil from pathlib import Path from typing import AsyncIterable, TextIO import pytest +from pytest_snapshot.plugin import Snapshot from replit_river.client import Client from replit_river.codegen.client import schema_to_river_client_codegen -from tests.codegen.stream.generated.test_service.stream_method import ( - Stream_MethodInput, - Stream_MethodOutput, -) +from tests.codegen.snapshot.test_enum import UnclosableStringIO from tests.common_handlers import basic_stream -@pytest.fixture(scope="session", autouse=True) -def generate_stream_client() -> None: - import tests.codegen.stream.generated - - shutil.rmtree("tests/codegen/stream/generated") +@pytest.mark.parametrize("handlers", [{**basic_stream}]) +async def test_basic_stream(snapshot: Snapshot, client: Client) -> None: + snapshot.snapshot_dir = "tests/codegen/snapshot/snapshots" + files: dict[Path, UnclosableStringIO] = {} def file_opener(path: Path) -> TextIO: - return open(path, "w") + buffer = UnclosableStringIO() + assert path not in files, "Codegen attempted to write to the same file twice!" + files[path] = buffer + return buffer schema_to_river_client_codegen( - lambda: open("tests/codegen/stream/schema.json"), - "tests/codegen/stream/generated", - "StreamClient", - True, - file_opener, + read_schema=lambda: open("tests/codegen/stream/schema.json"), + target_path="test_basic_stream", + client_name="StreamClient", + file_opener=file_opener, + typed_dict_inputs=True, ) - importlib.reload(tests.codegen.stream.generated) + for path, file in files.items(): + file.seek(0) + snapshot.assert_match(file.read(), Path(snapshot.snapshot_dir, path)) -@pytest.mark.asyncio -@pytest.mark.parametrize("handlers", [{**basic_stream}]) -async def test_basic_stream(client: Client) -> None: - from tests.codegen.stream.generated import StreamClient + import tests.codegen.snapshot.snapshots.test_basic_stream + + importlib.reload(tests.codegen.snapshot.snapshots.test_basic_stream) + from tests.codegen.snapshot.snapshots.test_basic_stream import ( + StreamClient, # noqa: E501 + ) + from tests.codegen.snapshot.snapshots.test_basic_stream.test_service.stream_method import ( # noqa: E501 + Stream_MethodInput, + Stream_MethodOutput, + ) async def emit() -> AsyncIterable[Stream_MethodInput]: for i in range(5): From 949761ae3764f840134e3cc96db9406118199870 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 17 Mar 2025 16:22:11 -0700 Subject: [PATCH 03/14] Bumping generated code --- .../snapshots/test_basic_stream/test_service/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py index b0d628b2..4133aea3 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py @@ -25,7 +25,7 @@ def __init__(self, client: river.Client[Any]): async def stream_method( self, inputStream: AsyncIterable[Stream_MethodInput], - ) -> AsyncIterator[Stream_MethodOutput | RiverError]: + ) -> AsyncIterator[Stream_MethodOutput | RiverError | RiverError]: return self.client.send_stream( "test_service", "stream_method", From 989866481286b927c294d2e7b935fa2577714c7d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 17 Mar 2025 13:55:39 -0700 Subject: [PATCH 04/14] Adding scripts and tests to lint command --- scripts/lint/src/lint/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/lint/src/lint/__init__.py b/scripts/lint/src/lint/__init__.py index 6ac5fa1c..f5b5c3d3 100644 --- a/scripts/lint/src/lint/__init__.py +++ b/scripts/lint/src/lint/__init__.py @@ -11,7 +11,7 @@ def raise_err(code: int) -> None: def main() -> None: fix = ["--fix"] if "--fix" in sys.argv else [] - raise_err(os.system(" ".join(["ruff", "check", "src"] + fix))) - raise_err(os.system("ruff format src")) + raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix))) + raise_err(os.system("ruff format src scripts tests")) raise_err(os.system("mypy src")) raise_err(os.system("pyright src")) From 7f1170f447095a98484679868476e6bd9921ff47 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 15 Mar 2025 22:40:16 -0700 Subject: [PATCH 05/14] Switch from context manager to a straight binding to avoid deadlock --- tests/river_fixtures/clientserver.py | 46 ++++++++++++++++------------ 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 7c9d3172..3f4ff446 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -38,32 +38,38 @@ async def client( transport_options: TransportOptions, no_logging_error: NoErrors, ) -> AsyncGenerator[Client, None]: + binding = None try: - async with serve(server.serve, "127.0.0.1") as binding: - sockets = list(binding.sockets) - assert len(sockets) == 1, "Too many sockets!" - socket = sockets[0] + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1, "Too many sockets!" + socket = sockets[0] - async def websocket_uri_factory() -> UriAndMetadata[None]: - return { - "uri": "ws://%s:%d" % socket.getsockname(), - "metadata": None, - } + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + yield client + finally: + logging.debug("Start closing test client : %s", "test_client") + await client.close() - client: Client[Literal[None]] = Client[None]( - uri_and_metadata_factory=websocket_uri_factory, - client_id="test_client", - server_id="test_server", - transport_options=transport_options, - ) - try: - yield client - finally: - logging.debug("Start closing test client : %s", "test_client") - await client.close() finally: await asyncio.sleep(1) logging.debug("Start closing test server") + if binding: + binding.close() await server.close() + if binding: + await binding.wait_closed() # Server should close normally no_logging_error() From 4e6128e5686d2e627655e33a7ca2a0e9958573bb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 7 Nov 2024 11:17:35 -0800 Subject: [PATCH 06/14] We need to explicitly import grpc.aio if we want to use it without errors --- src/replit_river/rpc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index ad537f33..a231c149 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -21,6 +21,7 @@ ) import grpc +import grpc.aio from aiochannel import Channel, ChannelClosed from opentelemetry.propagators.textmap import Setter from pydantic import BaseModel, ConfigDict, Field From 00af5ac73fd9f19ac01bb3e11eeaa4aaa1ff7019 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 5 Nov 2024 22:50:43 -0800 Subject: [PATCH 07/14] Swapping lambda for an equivalent "def" --- src/replit_river/codegen/client.py | 55 +++++++++++++++++++----------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 36ddbc76..a22bfa09 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -78,7 +78,6 @@ import datetime from typing import ( Any, - Callable, Dict, List, Literal, @@ -332,22 +331,27 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: encoder_name = TypeName(f"encode_{render_literal_type(prefix)}") encoder_names.add(encoder_name) _field_name = render_literal_type(encoder_name) - _field_type = ( - f"Callable[[{repr(render_literal_type(prefix))}], Any]" - ) + typeddict_encoder = typeddict_encoder[:-1] # Drop the last ternary chunks.append( FileContents( "\n".join( [ dedent( f"""\ - {_field_name}: {_field_type} = ( - lambda x: - """.rstrip() + def {_field_name}( + x: {repr(render_literal_type(prefix))}, + ) -> Any: + return ( + { + reindent( + " ", + "\n".join(typeddict_encoder), + ) + } + ) + """.rstrip() ) ] - + typeddict_encoder[:-1] # Drop the last ternary - + [")"] ) ) ) @@ -404,13 +408,20 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: encoder_name = TypeName(f"encode_{render_literal_type(prefix)}") encoder_names.add(encoder_name) _field_name = render_literal_type(encoder_name) - _field_type = f"Callable[[{repr(render_literal_type(prefix))}], Any]" chunks.append( FileContents( - "\n".join( - [f"{_field_name}: {_field_type} = (lambda x: "] - + typeddict_encoder - + [")"] + dedent( + f""" + def {_field_name}(x: {repr(render_literal_type(prefix))}) -> Any: + return ( + { + reindent( + " ", + "\n".join(typeddict_encoder), + ) + } + ) + """ ) ) ) @@ -702,7 +713,6 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: encoder_name = TypeName(f"encode_{render_literal_type(prefix)}") encoder_names.add(encoder_name) _field_name = render_literal_type(encoder_name) - _field_type = f"Callable[[{repr(render_literal_type(prefix))}], Any]" current_chunks.insert( 0, FileContents( @@ -710,13 +720,20 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: [ dedent( f"""\ - {_field_name}: {_field_type} = ( - lambda {binding}: + def {_field_name}( + {binding}: {repr(render_literal_type(prefix))}, + ) -> Any: + return ( + { + reindent( + " ", + "\n".join(typeddict_encoder), + ) + } + ) """ ) ] - + typeddict_encoder - + [")"] ) ), ) From 521c2ca4366f0b130ec9086c44135159f85ae015 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 5 Nov 2024 14:12:00 -0800 Subject: [PATCH 08/14] This case should have been handled above by is_literal --- src/replit_river/codegen/client.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index a22bfa09..5028acf8 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -155,7 +155,7 @@ def is_literal(tpe: RiverType) -> bool: if isinstance(tpe, RiverUnionType): return all(is_literal(t) for t in tpe.anyOf) elif isinstance(tpe, RiverConcreteType): - return tpe.type in set(["string", "number", "boolean"]) + return tpe.type not in set(["object", "array"]) else: return False @@ -988,15 +988,6 @@ def __init__(self, client: river.Client[Any]): exclude_none=True, ) """ - if ( - ( - isinstance(procedure.input, RiverConcreteType) - and procedure.input.type not in ["object", "array"] - ) - or isinstance(procedure.input, RiverNotType) - or procedure.input is None - ): - render_input_method = "lambda x: x" assert render_input_method, ( f"Unable to derive the input encoder from: {input_type}" From ea63428cdc3f3ebed3a3f53d16de713f34985aac Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 14 Mar 2025 14:26:01 -0700 Subject: [PATCH 09/14] List, Dict, Tuple -> list, dict, tuple --- src/replit_river/client_transport.py | 6 ++--- src/replit_river/codegen/client.py | 40 ++++++++++++---------------- src/replit_river/codegen/schema.py | 16 +++++------ src/replit_river/codegen/server.py | 18 ++++++------- src/replit_river/error_schema.py | 4 +-- src/replit_river/rate_limiter.py | 9 +++---- src/replit_river/rpc.py | 6 ++--- src/replit_river/server.py | 4 +-- src/replit_river/server_transport.py | 4 +-- src/replit_river/session.py | 8 +++--- src/replit_river/transport.py | 5 ++-- 11 files changed, 55 insertions(+), 65 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 7ae62b23..be5d9d6a 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -1,7 +1,7 @@ import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic, Optional, Tuple +from typing import Generic, Optional import websockets from pydantic import ValidationError @@ -118,7 +118,7 @@ async def _get_existing_session(self) -> Optional[ClientSession]: async def _establish_new_connection( self, old_session: Optional[ClientSession] = None, - ) -> Tuple[ + ) -> tuple[ WebSocketCommonProtocol, ControlMessageHandshakeRequest[HandshakeMetadataType], ControlMessageHandshakeResponse, @@ -292,7 +292,7 @@ async def _establish_handshake( handshake_metadata: HandshakeMetadataType, websocket: WebSocketCommonProtocol, old_session: Optional[ClientSession], - ) -> Tuple[ + ) -> tuple[ ControlMessageHandshakeRequest[HandshakeMetadataType], ControlMessageHandshakeResponse, ]: diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 5028acf8..bfe17cd0 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -6,15 +6,12 @@ from typing import ( Any, Callable, - Dict, - List, Literal, Optional, OrderedDict, Sequence, Set, TextIO, - Tuple, Union, cast, ) @@ -78,14 +75,11 @@ import datetime from typing import ( Any, - Dict, - List, Literal, Optional, Mapping, NotRequired, Union, - Tuple, TypedDict, ) from typing_extensions import Annotated @@ -102,19 +96,19 @@ class RiverConcreteType(BaseModel): type: Optional[str] = Field(default=None) - properties: Dict[str, "RiverType"] = Field(default_factory=lambda: dict()) + properties: dict[str, "RiverType"] = Field(default_factory=lambda: dict()) required: Set[str] = Field(default=set()) items: Optional["RiverType"] = Field(default=None) const: Optional[Union[str, int]] = Field(default=None) - patternProperties: Dict[str, "RiverType"] = Field(default_factory=lambda: dict()) + patternProperties: dict[str, "RiverType"] = Field(default_factory=lambda: dict()) class RiverUnionType(BaseModel): - anyOf: List["RiverType"] + anyOf: list["RiverType"] class RiverIntersectionType(BaseModel): - allOf: List["RiverType"] + allOf: list["RiverType"] class RiverNotType(BaseModel): @@ -140,11 +134,11 @@ class RiverProcedure(BaseModel): class RiverService(BaseModel): - procedures: Dict[str, RiverProcedure] + procedures: dict[str, RiverProcedure] class RiverSchema(BaseModel): - services: Dict[str, RiverService] + services: dict[str, RiverService] handshakeSchema: Optional[RiverConcreteType] = Field(default=None) @@ -166,9 +160,9 @@ def encode_type( base_model: str, in_module: list[ModuleName], permit_unknown_members: bool, -) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: +) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: encoder_name: TypeName | None = None # defining this up here to placate mypy - chunks: List[FileContents] = [] + chunks: list[FileContents] = [] if isinstance(type, RiverNotType): return (NoneTypeExpr(), [], [], set()) elif isinstance(type, RiverUnionType): @@ -190,7 +184,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: type = RiverUnionType(anyOf=flatten_union(type)) - one_of_candidate_types: List[RiverConcreteType] = [ + one_of_candidate_types: list[RiverConcreteType] = [ t for _t in type.anyOf for t in (_t.anyOf if isinstance(_t, RiverUnionType) else [_t]) @@ -238,7 +232,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: (discriminator_value, []), )[1].append(oneof_t) - one_of: List[TypeExpression] = [] + one_of: list[TypeExpression] = [] if discriminator_name == "$kind": discriminator_name = "kind" for pfx, (discriminator_value, oneof_ts) in one_of_pending.items(): @@ -359,7 +353,7 @@ def {_field_name}( # End of stable union detection # Restore the non-flattened union type type = original_type - any_of: List[TypeExpression] = [] + any_of: list[TypeExpression] = [] typeddict_encoder = [] for i, t in enumerate(type.anyOf): @@ -531,7 +525,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names) assert type.type == "object", type.type - current_chunks: List[str] = [ + current_chunks: list[str] = [ f"class {render_literal_type(prefix)}({base_model}):" ] # For the encoder path, do we need "x" to be bound? @@ -746,7 +740,7 @@ def generate_common_client( client_name: str, handshake_type: HandshakeType, handshake_chunks: Sequence[str], - modules: list[Tuple[ModuleName, ClassName]], + modules: list[tuple[ModuleName, ClassName]], ) -> FileContents: chunks: list[str] = [ROOT_FILE_HEADER] chunks.extend( @@ -778,8 +772,8 @@ def generate_individual_service( schema_name: str, schema: RiverService, input_base_class: Literal["TypedDict"] | Literal["BaseModel"], -) -> Tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: - serdes: list[Tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] +) -> tuple[ModuleName, ClassName, dict[RenderedPath, FileContents]]: + serdes: list[tuple[list[TypeName], list[ModuleName], list[FileContents]]] = [] def _type_adapter_definition( type_adapter_name: TypeName, @@ -802,7 +796,7 @@ def _type_adapter_definition( ) class_name = ClassName(f"{schema_name.title()}Service") - current_chunks: List[str] = [ + current_chunks: list[str] = [ dedent( f"""\ class {class_name}: @@ -1216,7 +1210,7 @@ def generate_river_client_module( else: handshake_type = HandshakeType("Literal[None]") - modules: list[Tuple[ModuleName, ClassName]] = [] + modules: list[tuple[ModuleName, ClassName]] = [] input_base_class: Literal["TypedDict"] | Literal["BaseModel"] = ( "TypedDict" if typed_dict_inputs else "BaseModel" ) diff --git a/src/replit_river/codegen/schema.py b/src/replit_river/codegen/schema.py index c7a924ca..004cd8ce 100644 --- a/src/replit_river/codegen/schema.py +++ b/src/replit_river/codegen/schema.py @@ -4,7 +4,7 @@ import json import os.path import tempfile -from typing import Any, DefaultDict, Dict, List +from typing import Any, DefaultDict import grpc_tools # type: ignore from google.protobuf import descriptor_pb2 @@ -29,15 +29,15 @@ def message_type( module_name: str, m: descriptor_pb2.DescriptorProto, sender: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Generates the type of a protobuf message into Typebox descriptions.""" - type: Dict[str, Any] = { + type: dict[str, Any] = { "type": "object", "properties": {}, "required": [], } # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( + oneofs: DefaultDict[int, list[descriptor_pb2.FieldDescriptorProto]] = ( collections.defaultdict(list) ) for field in m.field: @@ -63,11 +63,11 @@ def message_type( def generate_river_schema( module_name: str, fds: descriptor_pb2.FileDescriptorSet, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Generates the JSON schema of a River module.""" - service_schemas: List[Dict[str, Any]] = [] + service_schemas: list[dict[str, Any]] = [] - message_types: Dict[str, descriptor_pb2.DescriptorProto] = {} + message_types: dict[str, descriptor_pb2.DescriptorProto] = {} for pd in fds.file: for message in pd.message_type: @@ -80,7 +80,7 @@ def _remove_namespace(name: str) -> str: # Generate the service stubs. for service in pd.service: - service_schema: Dict[str, Any] = { + service_schema: dict[str, Any] = { "name": "".join([service.name[0].lower(), service.name[1:]]), "state": {}, "procedures": {}, diff --git a/src/replit_river/codegen/server.py b/src/replit_river/codegen/server.py index cd2f706c..597c5975 100644 --- a/src/replit_river/codegen/server.py +++ b/src/replit_river/codegen/server.py @@ -3,7 +3,7 @@ import subprocess import tempfile from textwrap import dedent -from typing import DefaultDict, List, Sequence +from typing import DefaultDict, Sequence import grpc_tools # type: ignore from google.protobuf import descriptor_pb2 @@ -53,7 +53,7 @@ def _{m.name}Decoder( ), ] # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( + oneofs: DefaultDict[int, list[descriptor_pb2.FieldDescriptorProto]] = ( collections.defaultdict(list) ) for field in m.field: @@ -224,13 +224,13 @@ def message_encoder( f"""\ def _{m.name}Encoder( e: {module_name}_pb2.{m.name} - ) -> Dict[str, Any]: - d: Dict[str, Any] = {{}} + ) -> dict[str, Any]: + d: dict[str, Any] = {{}} """ ), ] # Non-oneof fields. - oneofs: DefaultDict[int, List[descriptor_pb2.FieldDescriptorProto]] = ( + oneofs: DefaultDict[int, list[descriptor_pb2.FieldDescriptorProto]] = ( collections.defaultdict(list) ) for field in m.field: @@ -302,12 +302,12 @@ def generate_river_module( fds: descriptor_pb2.FileDescriptorSet, ) -> Sequence[str]: """Generates the lines of a River module.""" - chunks: List[str] = [ + chunks: list[str] = [ dedent( f"""\ # Code generated by river.codegen. DO NOT EDIT. import datetime - from typing import Any, Dict, Mapping, Tuple + from typing import Any, Mapping from google.protobuf import timestamp_pb2 from google.protobuf.wrappers_pb2 import BoolValue @@ -340,8 +340,8 @@ def add_{service.name}Servicer_to_server( server: river.Server, ) -> None: rpc_method_handlers: Mapping[ - Tuple[str, str], - Tuple[str, river.GenericRpcHandler] + tuple[str, str], + tuple[str, river.GenericRpcHandler] ] = {{ """ ), diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index a97fbc9c..b8d0439f 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Optional from pydantic import BaseModel, TypeAdapter @@ -94,7 +94,7 @@ def stringify_exception(e: BaseException, limit: int = 10) -> str: if e.__cause__ is None: # If there are no causes, just fall back to stringifying the exception. return str(e) - causes: List[str] = [] + causes: list[str] = [] cause: Optional[BaseException] = e while cause and limit: causes.append(str(cause)) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index 32b6df7a..b9265eee 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -1,7 +1,6 @@ import asyncio import random from contextvars import Context -from typing import Dict from replit_river.transport_options import ConnectionRetryOptions @@ -15,16 +14,16 @@ class LeakyBucketRateLimit: Attributes: options (ConnectionRetryOptions): Configuration options for retry behavior. - budget_consumed (Dict[str, int]): Dictionary tracking the number of retries + budget_consumed (dict[str, int]): Dictionary tracking the number of retries (or budget) consumed per user. - tasks (Dict[str, asyncio.Task]): Dictionary holding asyncio tasks for budget + tasks (dict[str, asyncio.Task]): Dictionary holding asyncio tasks for budget restoration. """ def __init__(self, options: ConnectionRetryOptions): self.options = options - self.budget_consumed: Dict[str, int] = {} - self.tasks: Dict[str, asyncio.Task] = {} + self.budget_consumed: dict[str, int] = {} + self.tasks: dict[str, asyncio.Task] = {} def get_backoff_ms(self, user: str) -> float: """Calculate the backoff time in milliseconds for a user. diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index a231c149..8c314937 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -6,7 +6,6 @@ Awaitable, Callable, Coroutine, - Dict, Generic, Iterable, Iterator, @@ -15,7 +14,6 @@ NoReturn, Optional, Sequence, - Tuple, TypeVar, Union, ) @@ -45,7 +43,7 @@ ResponseType = TypeVar("ResponseType") ErrorType = TypeVar("ErrorType", bound=RiverError) -_MetadataType = Union[grpc.aio.Metadata, Sequence[Tuple[str, Union[str, bytes]]]] +_MetadataType = Union[grpc.aio.Metadata, Sequence[tuple[str, Union[str, bytes]]]] GenericRpcHandler = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] @@ -202,7 +200,7 @@ async def write(self, message: ResponseType) -> None: def get_response_or_error_payload( response: Any, response_serializer: Callable[[ResponseType], Any] -) -> Dict: +) -> dict[Any, Any]: if isinstance(response, RiverError): return { "ok": False, diff --git a/src/replit_river/server.py b/src/replit_river/server.py index c4ec23d0..25ae19ae 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Mapping, Optional, Tuple +from typing import Mapping, Optional import websockets from websockets.exceptions import ConnectionClosed @@ -35,7 +35,7 @@ async def close(self) -> None: def add_rpc_handlers( self, - rpc_handlers: Mapping[Tuple[str, str], Tuple[str, GenericRpcHandler]], + rpc_handlers: Mapping[tuple[str, str], tuple[str, GenericRpcHandler]], ) -> None: self._transport._handlers.update(rpc_handlers) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index c2e30657..1296e4dd 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional, Tuple +from typing import Any, Optional import nanoid # type: ignore # type: ignore from pydantic import ValidationError @@ -192,7 +192,7 @@ async def websocket_closed_callback() -> None: async def _establish_handshake( self, request_message: TransportMessage, websocket: WebSocketCommonProtocol - ) -> Tuple[ + ) -> tuple[ ControlMessageHandshakeRequest[Any], ControlMessageHandshakeResponse, ]: diff --git a/src/replit_river/session.py b/src/replit_river/session.py index a69f212c..79f777fe 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,7 +1,7 @@ import asyncio import enum import logging -from typing import Any, Callable, Coroutine, Dict, Optional, Tuple +from typing import Any, Callable, Coroutine, Optional import nanoid # type: ignore import websockets @@ -64,7 +64,7 @@ def __init__( websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, is_server: bool, - handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]], + handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]], close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], retry_connection_callback: Optional[ Callable[ @@ -94,7 +94,7 @@ def __init__( # stream for tasks self._stream_lock = asyncio.Lock() - self._streams: Dict[str, Channel[Any]] = {} + self._streams: dict[str, Channel[Any]] = {} # book keeping self._seq_manager = SeqManager() @@ -370,7 +370,7 @@ async def get_next_expected_ack(self) -> int: async def send_message( self, stream_id: str, - payload: Dict | str, + payload: dict[Any, Any] | str, control_flags: int = 0, service_name: str | None = None, procedure_name: str | None = None, diff --git a/src/replit_river/transport.py b/src/replit_river/transport.py index d4395bea..f0e2b920 100644 --- a/src/replit_river/transport.py +++ b/src/replit_river/transport.py @@ -1,6 +1,5 @@ import asyncio import logging -from typing import Dict, Tuple import nanoid # type: ignore @@ -23,8 +22,8 @@ def __init__( self._transport_id = transport_id self._transport_options = transport_options self._is_server = is_server - self._sessions: Dict[str, Session] = {} - self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {} + self._sessions: dict[str, Session] = {} + self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {} self._session_lock = asyncio.Lock() async def _close_all_sessions(self) -> None: From 4af86374d150cda119930ed8e322c077c90fc4d1 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 5 Nov 2024 23:26:36 -0800 Subject: [PATCH 10/14] Union -> | --- scripts/parity/check_parity.py | 62 +++++++++++++++--------------- src/replit_river/client.py | 8 ++-- src/replit_river/client_session.py | 6 +-- src/replit_river/codegen/client.py | 8 +--- src/replit_river/rpc.py | 4 +- 5 files changed, 43 insertions(+), 45 deletions(-) diff --git a/scripts/parity/check_parity.py b/scripts/parity/check_parity.py index 8e75532c..c89a6918 100644 --- a/scripts/parity/check_parity.py +++ b/scripts/parity/check_parity.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Literal, TypedDict, TypeVar, Union +from typing import Any, Callable, Literal, TypedDict, TypeVar import pyd import tyd @@ -85,35 +85,37 @@ def testAgenttoollanguageserverOpendocumentInput() -> None: ) -kind_type = Union[ - Literal[1], - Literal[2], - Literal[3], - Literal[4], - Literal[5], - Literal[6], - Literal[7], - Literal[8], - Literal[9], - Literal[10], - Literal[11], - Literal[12], - Literal[13], - Literal[14], - Literal[15], - Literal[16], - Literal[17], - Literal[18], - Literal[19], - Literal[20], - Literal[21], - Literal[22], - Literal[23], - Literal[24], - Literal[25], - Literal[26], - None, -] +kind_type = ( + Literal[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + | None +) def testAgenttoollanguageserverGetcodesymbolInput() -> None: diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 90a695b1..9e7600fd 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union +from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional from opentelemetry import trace from opentelemetry.trace import Span, SpanKind, Status, StatusCode @@ -129,7 +129,7 @@ async def send_subscription( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], - ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: + ) -> AsyncGenerator[ResponseType | RiverError, None]: with _trace_procedure( "subscription", service_name, procedure_name ) as span_handle: @@ -157,7 +157,7 @@ async def send_stream( request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], - ) -> AsyncGenerator[Union[ResponseType, RiverError], None]: + ) -> AsyncGenerator[ResponseType | RiverError, None]: with _trace_procedure("stream", service_name, procedure_name) as span_handle: session = await self._transport.get_or_create_session() async for msg in session.send_stream( @@ -185,7 +185,7 @@ class _SpanHandle: def set_status( self, - status: Union[Status, StatusCode], + status: Status | StatusCode, description: Optional[str] = None, ) -> None: if self.did_set_status: diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index a5a0c44e..f54f0b33 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Optional, Union +from typing import Any, AsyncGenerator, Callable, Optional import nanoid # type: ignore from aiochannel import Channel @@ -194,7 +194,7 @@ async def send_subscription( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], span: Span, - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[ResponseType | ErrorType, None]: """Sends a subscription request to the server. Expects the input and output be messages that will be msgpacked. @@ -248,7 +248,7 @@ async def send_stream( response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], span: Span, - ) -> AsyncGenerator[Union[ResponseType, ErrorType], None]: + ) -> AsyncGenerator[ResponseType | ErrorType, None]: """Sends a subscription request to the server. Expects the input and output be messages that will be msgpacked. diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index bfe17cd0..75fa54aa 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -12,7 +12,6 @@ Sequence, Set, TextIO, - Union, cast, ) @@ -79,7 +78,6 @@ Optional, Mapping, NotRequired, - Union, TypedDict, ) from typing_extensions import Annotated @@ -99,7 +97,7 @@ class RiverConcreteType(BaseModel): properties: dict[str, "RiverType"] = Field(default_factory=lambda: dict()) required: Set[str] = Field(default=set()) items: Optional["RiverType"] = Field(default=None) - const: Optional[Union[str, int]] = Field(default=None) + const: Optional[str | int] = Field(default=None) patternProperties: dict[str, "RiverType"] = Field(default_factory=lambda: dict()) @@ -117,9 +115,7 @@ class RiverNotType(BaseModel): not_: Any = Field(..., alias="not") -RiverType = Union[ - RiverConcreteType, RiverUnionType, RiverNotType, RiverIntersectionType -] +RiverType = RiverConcreteType | RiverUnionType | RiverNotType | RiverIntersectionType class RiverProcedure(BaseModel): diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 8c314937..acee5fa1 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -14,8 +14,8 @@ NoReturn, Optional, Sequence, + TypeAlias, TypeVar, - Union, ) import grpc @@ -43,7 +43,7 @@ ResponseType = TypeVar("ResponseType") ErrorType = TypeVar("ErrorType", bound=RiverError) -_MetadataType = Union[grpc.aio.Metadata, Sequence[tuple[str, Union[str, bytes]]]] +_MetadataType: TypeAlias = grpc.aio.Metadata | Sequence[tuple[str, str | bytes]] GenericRpcHandler = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] From c207b3624f92acb0381142707a017c5b96e4bae9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 14 Mar 2025 13:48:45 -0700 Subject: [PATCH 11/14] Optional -> A | None --- scripts/parity/gen.py | 4 ++-- src/replit_river/client.py | 12 ++++++------ src/replit_river/client_session.py | 10 +++++----- src/replit_river/client_transport.py | 12 ++++++------ src/replit_river/codegen/client.py | 28 +++++++++++++--------------- src/replit_river/error_schema.py | 6 +++--- src/replit_river/message_buffer.py | 3 +-- src/replit_river/rpc.py | 25 ++++++++++++------------- src/replit_river/server.py | 4 ++-- src/replit_river/server_transport.py | 6 +++--- src/replit_river/session.py | 13 +++++++------ src/replit_river/task_manager.py | 4 ++-- 12 files changed, 62 insertions(+), 65 deletions(-) diff --git a/scripts/parity/gen.py b/scripts/parity/gen.py index 92bae50b..b719c3f6 100644 --- a/scripts/parity/gen.py +++ b/scripts/parity/gen.py @@ -1,6 +1,6 @@ import random import string -from typing import Callable, Optional, TypeVar +from typing import Callable, TypeVar A = TypeVar("A") @@ -37,7 +37,7 @@ def gen_choice(choices: list[A]) -> Callable[[], A]: return lambda: random.choice(choices) -def gen_opt(gen_x: Callable[[], A]) -> Callable[[], Optional[A]]: +def gen_opt(gen_x: Callable[[], A]) -> Callable[[], A | None]: return lambda: gen_x() if gen_bool() else None diff --git a/src/replit_river/client.py b/src/replit_river/client.py index 9e7600fd..8598c47c 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional +from typing import Any, AsyncGenerator, Generator, Generic, Literal from opentelemetry import trace from opentelemetry.trace import Span, SpanKind, Status, StatusCode @@ -100,9 +100,9 @@ async def send_upload( self, service_name: str, procedure_name: str, - init: Optional[InitType], + init: InitType | None, request: AsyncIterable[RequestType], - init_serializer: Optional[Callable[[InitType], Any]], + init_serializer: Callable[[InitType], Any] | None, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], @@ -151,9 +151,9 @@ async def send_stream( self, service_name: str, procedure_name: str, - init: Optional[InitType], + init: InitType | None, request: AsyncIterable[RequestType], - init_serializer: Optional[Callable[[InitType], Any]], + init_serializer: Callable[[InitType], Any] | None, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], @@ -186,7 +186,7 @@ class _SpanHandle: def set_status( self, status: Status | StatusCode, - description: Optional[str] = None, + description: str | None = None, ) -> None: if self.did_set_status: return diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index f54f0b33..479a9f50 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable, Optional +from typing import Any, AsyncGenerator, Callable import nanoid # type: ignore from aiochannel import Channel @@ -102,9 +102,9 @@ async def send_upload( self, service_name: str, procedure_name: str, - init: Optional[InitType], + init: InitType | None, request: AsyncIterable[RequestType], - init_serializer: Optional[Callable[[InitType], Any]], + init_serializer: Callable[[InitType], Any] | None, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], @@ -241,9 +241,9 @@ async def send_stream( self, service_name: str, procedure_name: str, - init: Optional[InitType], + init: InitType | None, request: AsyncIterable[RequestType], - init_serializer: Optional[Callable[[InitType], Any]], + init_serializer: Callable[[InitType], Any] | None, request_serializer: Callable[[RequestType], Any], response_deserializer: Callable[[Any], ResponseType], error_deserializer: Callable[[Any], ErrorType], diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index be5d9d6a..79552b58 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -1,7 +1,7 @@ import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic, Optional +from typing import Generic import websockets from pydantic import ValidationError @@ -98,7 +98,7 @@ async def get_or_create_session(self) -> ClientSession: await existing_session.close() return await self._create_new_session() - async def _get_existing_session(self) -> Optional[ClientSession]: + async def _get_existing_session(self) -> ClientSession | None: async with self._session_lock: if not self._sessions: return None @@ -117,7 +117,7 @@ async def _get_existing_session(self) -> Optional[ClientSession]: async def _establish_new_connection( self, - old_session: Optional[ClientSession] = None, + old_session: ClientSession | None = None, ) -> tuple[ WebSocketCommonProtocol, ControlMessageHandshakeRequest[HandshakeMetadataType], @@ -129,7 +129,7 @@ async def _establish_new_connection( client_id = self._client_id logger.info("Attempting to establish new ws connection") - last_error: Optional[Exception] = None + last_error: Exception | None = None for i in range(max_retry): if i > 0: logger.info(f"Retrying build handshake number {i} times") @@ -221,7 +221,7 @@ async def _send_handshake_request( transport_id: str, to_id: str, session_id: str, - handshake_metadata: Optional[HandshakeMetadataType], + handshake_metadata: HandshakeMetadataType | None, websocket: WebSocketCommonProtocol, expected_session_state: ExpectedSessionState, ) -> ControlMessageHandshakeRequest[HandshakeMetadataType]: @@ -291,7 +291,7 @@ async def _establish_handshake( session_id: str, handshake_metadata: HandshakeMetadataType, websocket: WebSocketCommonProtocol, - old_session: Optional[ClientSession], + old_session: ClientSession | None, ) -> tuple[ ControlMessageHandshakeRequest[HandshakeMetadataType], ControlMessageHandshakeResponse, diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 75fa54aa..440e6634 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -7,7 +7,6 @@ Any, Callable, Literal, - Optional, OrderedDict, Sequence, Set, @@ -75,7 +74,6 @@ from typing import ( Any, Literal, - Optional, Mapping, NotRequired, TypedDict, @@ -93,11 +91,11 @@ class RiverConcreteType(BaseModel): - type: Optional[str] = Field(default=None) + type: str | None = Field(default=None) properties: dict[str, "RiverType"] = Field(default_factory=lambda: dict()) required: Set[str] = Field(default=set()) - items: Optional["RiverType"] = Field(default=None) - const: Optional[str | int] = Field(default=None) + items: "RiverType | None" = Field(default=None) + const: str | int | None = Field(default=None) patternProperties: dict[str, "RiverType"] = Field(default_factory=lambda: dict()) @@ -119,14 +117,14 @@ class RiverNotType(BaseModel): class RiverProcedure(BaseModel): - init: Optional[RiverType] = Field(default=None) + init: RiverType | None = Field(default=None) input: RiverType output: RiverType - errors: Optional[RiverType] = Field(default=None) + errors: RiverType | None = Field(default=None) type: ( Literal["rpc"] | Literal["stream"] | Literal["subscription"] | Literal["upload"] ) - description: Optional[str] = Field(default=None) + description: str | None = Field(default=None) class RiverService(BaseModel): @@ -135,7 +133,7 @@ class RiverService(BaseModel): class RiverSchema(BaseModel): services: dict[str, RiverService] - handshakeSchema: Optional[RiverConcreteType] = Field(default=None) + handshakeSchema: RiverConcreteType | None = Field(default=None) RiverSchemaFile = RootModel[RiverSchema] @@ -645,7 +643,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: """ ) current_chunks.append( - f" kind: Optional[{render_type_expr(type_name)}]{value}" + f" kind: {render_type_expr(type_name)} | None{value}" ) else: value = "" @@ -668,7 +666,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: reindent( " ", f"""\ - {name}: NotRequired[Optional[{render_type_expr(type_name)}]] + {name}: NotRequired[{render_type_expr(type_name)}] | None """, ) ) @@ -677,7 +675,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: reindent( " ", f"""\ - {name}: Optional[{render_type_expr(type_name)}] = None + {name}: {render_type_expr(type_name)} | None = None """, ) ) @@ -803,7 +801,7 @@ def __init__(self, client: river.Client[Any]): ] for name, procedure in schema.procedures.items(): module_names = [ModuleName(name)] - init_type: Optional[TypeExpression] = None + init_type: TypeExpression | None = None if procedure.init: init_type, module_info, init_chunks, encoder_names = encode_type( procedure.init, @@ -918,7 +916,7 @@ def __init__(self, client: river.Client[Any]): """ # Init renderer - render_init_method: Optional[str] = None + render_init_method: str | None = None if init_type and procedure.init is not None: if input_base_class == "TypedDict": if is_literal(procedure.init): @@ -953,7 +951,7 @@ def __init__(self, client: river.Client[Any]): ) # Input renderer - render_input_method: Optional[str] = None + render_input_method: str | None = None if input_base_class == "TypedDict": if is_literal(procedure.input): render_input_method = "lambda x: x" diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index b8d0439f..6aae861d 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, TypeAdapter @@ -41,7 +41,7 @@ class RiverServiceException(RiverException): """Exception raised by river as a result of a fault in the service running river.""" def __init__( - self, code: str, message: str, service: Optional[str], procedure: Optional[str] + self, code: str, message: str, service: str | None, procedure: str | None ) -> None: self.code = code self.message = message @@ -95,7 +95,7 @@ def stringify_exception(e: BaseException, limit: int = 10) -> str: # If there are no causes, just fall back to stringifying the exception. return str(e) causes: list[str] = [] - cause: Optional[BaseException] = e + cause: BaseException | None = e while cause and limit: causes.append(str(cause)) cause = cause.__cause__ diff --git a/src/replit_river/message_buffer.py b/src/replit_river/message_buffer.py index e07ff56e..8bcf023c 100644 --- a/src/replit_river/message_buffer.py +++ b/src/replit_river/message_buffer.py @@ -1,6 +1,5 @@ import asyncio import logging -from typing import Optional from replit_river.rpc import TransportMessage from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE @@ -41,7 +40,7 @@ async def put(self, message: TransportMessage) -> None: raise MessageBufferClosedError("message buffer is closed") self.buffer.append(message) - async def peek(self) -> Optional[TransportMessage]: + async def peek(self) -> TransportMessage | None: """Peek the first message in the buffer, returns None if the buffer is empty.""" async with self._lock: if len(self.buffer) == 0: diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index acee5fa1..1a3cebba 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -12,7 +12,6 @@ Literal, Mapping, NoReturn, - Optional, Sequence, TypeAlias, TypeVar, @@ -63,7 +62,7 @@ # Equivalent of https://github.com/replit/river/blob/c1345f1ff6a17a841d4319fad5c153b5bda43827/transport/message.ts#L23-L33 class ExpectedSessionState(BaseModel): nextExpectedSeq: int - nextSentSeq: Optional[int] = None + nextSentSeq: int | None = None class ControlMessageHandshakeRequest(BaseModel, Generic[HandshakeMetadataType]): @@ -71,14 +70,14 @@ class ControlMessageHandshakeRequest(BaseModel, Generic[HandshakeMetadataType]): protocolVersion: str sessionId: str expectedSessionState: ExpectedSessionState - metadata: Optional[HandshakeMetadataType] = None + metadata: HandshakeMetadataType | None = None class HandShakeStatus(BaseModel): ok: bool - sessionId: Optional[str] = None - reason: Optional[str] = None - code: Optional[str] = None + sessionId: str | None = None + reason: str | None = None + code: str | None = None class ControlMessageHandshakeResponse(BaseModel): @@ -98,11 +97,11 @@ class TransportMessage(BaseModel): to: str seq: int ack: int - serviceName: Optional[str] = None - procedureName: Optional[str] = None + serviceName: str | None = None + procedureName: str | None = None streamId: str controlFlags: int - tracing: Optional[PropagationContext] = None + tracing: PropagationContext | None = None payload: Any model_config = ConfigDict(populate_by_name=True) # need this because we create TransportMessage objects with destructuring @@ -131,8 +130,8 @@ class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]): def __init__(self, peer: str) -> None: self._peer = peer - self._abort_code: Optional[grpc.StatusCode] = None - self._abort_details: Optional[str] = None + self._abort_code: grpc.StatusCode | None = None + self._abort_details: str | None = None async def abort( self, @@ -157,10 +156,10 @@ def invocation_metadata(self) -> None: def peer(self) -> str: return self._peer - def peer_identities(self) -> Optional[Iterable[bytes]]: + def peer_identities(self) -> Iterable[bytes] | None: return None - def peer_identity_key(self) -> Optional[str]: + def peer_identity_key(self) -> str | None: return None async def read(self) -> RequestType: diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 25ae19ae..2bdf05b9 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Mapping, Optional +from typing import Mapping import websockets from websockets.exceptions import ConnectionClosed @@ -41,7 +41,7 @@ def add_rpc_handlers( async def _handshake_to_get_session( self, websocket: WebSocketServerProtocol - ) -> Optional[Session]: + ) -> Session | None: """This is a wrapper to make sentry happy, sentry doesn't recognize the exception handling outside of a task or asyncio.wait_for. So we need to catch the errors specifically here. diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 1296e4dd..888e0ce3 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional +from typing import Any import nanoid # type: ignore # type: ignore from pydantic import ValidationError @@ -98,8 +98,8 @@ async def _get_or_create_session( websocket: WebSocketCommonProtocol, ) -> Session: async with self._session_lock: - session_to_close: Optional[Session] = None - new_session: Optional[Session] = None + session_to_close: Session | None = None + new_session: Session | None = None if to_id not in self._sessions: logger.info( 'Creating new session with "%s" using ws: %s', to_id, websocket.id diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 79f777fe..a94156a8 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,7 +1,7 @@ import asyncio import enum import logging -from typing import Any, Callable, Coroutine, Optional +from typing import Any, Callable, Coroutine import nanoid # type: ignore import websockets @@ -66,12 +66,13 @@ def __init__( is_server: bool, handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]], close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], - retry_connection_callback: Optional[ + retry_connection_callback: ( Callable[ [], Coroutine[Any, Any, Any], ] - ] = None, + | None + ) = None, ) -> None: self._transport_id = transport_id self._to_id = to_id @@ -84,7 +85,7 @@ def __init__( self._state = SessionState.ACTIVE self._state_lock = asyncio.Lock() self._close_session_callback = close_session_callback - self._close_session_after_time_secs: Optional[float] = None + self._close_session_after_time_secs: float | None = None # ws state self._ws_lock = asyncio.Lock() @@ -172,7 +173,7 @@ async def _update_book_keeping(self, msg: TransportMessage) -> None: self._reset_session_close_countdown() async def _handle_messages_from_ws( - self, tg: Optional[asyncio.TaskGroup] = None + self, tg: asyncio.TaskGroup | None = None ) -> None: logger.debug( "%s start handling messages from ws %s", @@ -462,7 +463,7 @@ async def close_websocket( async def _open_stream_and_call_handler( self, msg: TransportMessage, - tg: Optional[asyncio.TaskGroup], + tg: asyncio.TaskGroup | None, ) -> Channel: if not self._is_server: raise InvalidMessageException("Client should not receive stream open bit") diff --git a/src/replit_river/task_manager.py b/src/replit_river/task_manager.py index 4064681d..531292d0 100644 --- a/src/replit_river/task_manager.py +++ b/src/replit_river/task_manager.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Coroutine, Optional, Set +from typing import Coroutine, Set from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException @@ -83,7 +83,7 @@ def _task_done_callback( ) def create_task( - self, fn: Coroutine[None, None, None], tg: Optional[asyncio.TaskGroup] = None + self, fn: Coroutine[None, None, None], tg: asyncio.TaskGroup | None = None ) -> asyncio.Task[None]: """Creates a task from a callable and adds it to the background tasks set. From 3eb41c295f0d55706698c2940a471647412b789f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 6 Nov 2024 18:13:05 -0800 Subject: [PATCH 12/14] Add noqa to remove the need for post-processing in consumers --- src/replit_river/codegen/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/replit_river/codegen/server.py b/src/replit_river/codegen/server.py index 597c5975..acee8106 100644 --- a/src/replit_river/codegen/server.py +++ b/src/replit_river/codegen/server.py @@ -305,6 +305,7 @@ def generate_river_module( chunks: list[str] = [ dedent( f"""\ + # ruff: noqa # Code generated by river.codegen. DO NOT EDIT. import datetime from typing import Any, Mapping From 971ca4cfc8f0bf4a50559cc00929780180f71c67 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Fri, 14 Mar 2025 16:20:50 -0700 Subject: [PATCH 13/14] Flatten nested unions, flatten Literal peers --- src/replit_river/codegen/typing.py | 45 ++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index ea1f00b7..3489945c 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -74,8 +74,33 @@ def __str__(self) -> str: ) +def _flatten_nested_unions(value: TypeExpression) -> TypeExpression: + def work( + value: TypeExpression, + ) -> tuple[list[TypeExpression], TypeExpression | None]: + match value: + case UnionTypeExpr(inner): + flattened: list[TypeExpression] = [] + for tpe in inner: + _union, _nonunion = work(tpe) + flattened.extend(_union) + if _nonunion is not None: + flattened.append(_nonunion) + return (flattened, None) + case other: + return ([], other) + + _inner, nonunion = work(value) + if nonunion and not _inner: + return nonunion + elif _inner and nonunion is None: + return UnionTypeExpr(_inner) + else: + raise ValueError("Incoherent state when trying to flatten unions") + + def render_type_expr(value: TypeExpression) -> str: - match value: + match _flatten_nested_unions(value): case DictTypeExpr(nested): return f"dict[str, {render_type_expr(nested)}]" case ListTypeExpr(nested): @@ -83,7 +108,23 @@ def render_type_expr(value: TypeExpression) -> str: case LiteralTypeExpr(inner): return f"Literal[{repr(inner)}]" case UnionTypeExpr(inner): - return " | ".join(render_type_expr(x) for x in inner) + literals: list[LiteralTypeExpr] = [] + _other: list[TypeExpression] = [] + for tpe in inner: + if isinstance(tpe, UnionTypeExpr): + raise ValueError("These should have been flattened") + elif isinstance(tpe, LiteralTypeExpr): + literals.append(tpe) + else: + _other.append(tpe) + retval: str = " | ".join(render_type_expr(x) for x in _other) + if literals: + _rendered: str = ", ".join(repr(x.nested) for x in literals) + if retval: + retval = f"Literal[{_rendered}] | {retval}" + else: + retval = f"Literal[{_rendered}]" + return retval case OpenUnionTypeExpr(inner): return ( "Annotated[" From 652a176ed13521101cc1355590bdc90c4ed31a70 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 17 Mar 2025 13:41:35 -0700 Subject: [PATCH 14/14] Regenerating tests --- .../rpc/generated/test_service/rpc_method.py | 27 +++---- .../test_service/stream_method.py | 27 +++---- .../enumService/needsEnum.py | 18 ++--- .../enumService/needsEnumObject.py | 79 ++++++++++--------- 4 files changed, 72 insertions(+), 79 deletions(-) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index d32b4645..91f0e562 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -4,15 +4,9 @@ import datetime from typing import ( Any, - Callable, - Dict, - List, Literal, - Optional, Mapping, NotRequired, - Union, - Tuple, TypedDict, ) from typing_extensions import Annotated @@ -24,15 +18,18 @@ import replit_river as river -encode_Rpc_MethodInput: Callable[["Rpc_MethodInput"], Any] = lambda x: { - k: v - for (k, v) in ( - { - "data": x.get("data"), - } - ).items() - if v is not None -} +def encode_Rpc_MethodInput( + x: "Rpc_MethodInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "data": x.get("data"), + } + ).items() + if v is not None + } class Rpc_MethodInput(TypedDict): diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 1294a67a..23d2ab6d 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -4,15 +4,9 @@ import datetime from typing import ( Any, - Callable, - Dict, - List, Literal, - Optional, Mapping, NotRequired, - Union, - Tuple, TypedDict, ) from typing_extensions import Annotated @@ -24,15 +18,18 @@ import replit_river as river -encode_Stream_MethodInput: Callable[["Stream_MethodInput"], Any] = lambda x: { - k: v - for (k, v) in ( - { - "data": x.get("data"), - } - ).items() - if v is not None -} +def encode_Stream_MethodInput( + x: "Stream_MethodInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "data": x.get("data"), + } + ).items() + if v is not None + } class Stream_MethodInput(TypedDict): diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index 95e6bd2c..dbe6e51e 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -4,15 +4,9 @@ import datetime from typing import ( Any, - Callable, - Dict, - List, Literal, - Optional, Mapping, NotRequired, - Union, - Tuple, TypedDict, ) from typing_extensions import Annotated @@ -24,20 +18,24 @@ import replit_river as river -NeedsenumInput = Literal["in_first"] | Literal["in_second"] -encode_NeedsenumInput: Callable[["NeedsenumInput"], Any] = lambda x: x +NeedsenumInput = Literal["in_first", "in_second"] + + +def encode_NeedsenumInput(x: "NeedsenumInput") -> Any: + return x + NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput) NeedsenumOutput = Annotated[ - Literal["out_first"] | Literal["out_second"] | RiverUnknownValue, + Literal["out_first", "out_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput) NeedsenumErrors = Annotated[ - Literal["err_first"] | Literal["err_second"] | RiverUnknownValue, + Literal["err_first", "err_second"] | RiverUnknownValue, WrapValidator(translate_unknown_value), ] diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 4e370433..75f00e1c 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -4,15 +4,9 @@ import datetime from typing import ( Any, - Callable, - Dict, - List, Literal, - Optional, Mapping, NotRequired, - Union, - Tuple, TypedDict, ) from typing_extensions import Annotated @@ -24,18 +18,19 @@ import replit_river as river -encode_NeedsenumobjectInputOneOf_in_first: Callable[ - ["NeedsenumobjectInputOneOf_in_first"], Any -] = lambda x: { - k: v - for (k, v) in ( - { - "$kind": x.get("kind"), - "value": x.get("value"), - } - ).items() - if v is not None -} +def encode_NeedsenumobjectInputOneOf_in_first( + x: "NeedsenumobjectInputOneOf_in_first", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "value": x.get("value"), + } + ).items() + if v is not None + } class NeedsenumobjectInputOneOf_in_first(TypedDict): @@ -43,18 +38,19 @@ class NeedsenumobjectInputOneOf_in_first(TypedDict): value: str -encode_NeedsenumobjectInputOneOf_in_second: Callable[ - ["NeedsenumobjectInputOneOf_in_second"], Any -] = lambda x: { - k: v - for (k, v) in ( - { - "$kind": x.get("kind"), - "bleep": x.get("bleep"), - } - ).items() - if v is not None -} +def encode_NeedsenumobjectInputOneOf_in_second( + x: "NeedsenumobjectInputOneOf_in_second", +) -> Any: + return { + k: v + for (k, v) in ( + { + "$kind": x.get("kind"), + "bleep": x.get("bleep"), + } + ).items() + if v is not None + } class NeedsenumobjectInputOneOf_in_second(TypedDict): @@ -66,11 +62,16 @@ class NeedsenumobjectInputOneOf_in_second(TypedDict): NeedsenumobjectInputOneOf_in_first | NeedsenumobjectInputOneOf_in_second ) -encode_NeedsenumobjectInput: Callable[["NeedsenumobjectInput"], Any] = ( - lambda x: encode_NeedsenumobjectInputOneOf_in_first(x) - if x["kind"] == "in_first" - else encode_NeedsenumobjectInputOneOf_in_second(x) -) + +def encode_NeedsenumobjectInput( + x: "NeedsenumobjectInput", +) -> Any: + return ( + encode_NeedsenumobjectInputOneOf_in_first(x) + if x["kind"] == "in_first" + else encode_NeedsenumobjectInputOneOf_in_second(x) + ) + NeedsenumobjectInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectInput) @@ -102,18 +103,18 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel): class NeedsenumobjectOutput(BaseModel): - foo: Optional[NeedsenumobjectOutputFoo] = None + foo: NeedsenumobjectOutputFoo | None = None NeedsenumobjectOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectOutput) class NeedsenumobjectErrorsFooAnyOf_0(RiverError): - beep: Optional[Literal["err_first"]] = None + beep: Literal["err_first"] | None = None class NeedsenumobjectErrorsFooAnyOf_1(RiverError): - borp: Optional[Literal["err_second"]] = None + borp: Literal["err_second"] | None = None NeedsenumobjectErrorsFoo = Annotated[ @@ -125,7 +126,7 @@ class NeedsenumobjectErrorsFooAnyOf_1(RiverError): class NeedsenumobjectErrors(RiverError): - foo: Optional[NeedsenumobjectErrorsFoo] = None + foo: NeedsenumobjectErrorsFoo | None = None NeedsenumobjectErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumobjectErrors)