Skip to content

Commit fe976e5

Browse files
committed
Accept a pre-built dispatcher via a constructor keyword
Replaces the from_dispatcher classmethod: read_stream/write_stream become optional and dispatcher is a keyword-only alternative, with mutual exclusion validated at construction. Drops the __new__-based alternate constructor and its shared state-init helper.
1 parent c722121 commit fe976e5

3 files changed

Lines changed: 53 additions & 76 deletions

File tree

docs/migration.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,7 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr
11681168

11691169
### `ClientSession` now runs on `JSONRPCDispatcher`; `BaseSession` removed
11701170

1171-
`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new `ClientSession.from_dispatcher(dispatcher, ...)` constructor accepts a pre-built dispatcher (for example a `DirectDispatcher` for in-process embedding).
1171+
`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new keyword-only `dispatcher=` constructor argument accepts a pre-built dispatcher instead of the stream pair (for example a `DirectDispatcher` for in-process embedding).
11721172

11731173
Behavior changes:
11741174

src/mcp/client/session.py

Lines changed: 28 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,25 @@ async def _default_logging_callback(
117117

118118

119119
class ClientSession:
120-
"""Client half of an MCP connection, running on `JSONRPCDispatcher`.
121-
122-
Construct it over a transport's stream pair, enter it as an async context
123-
manager, then call `initialize()`. The receive loop, request correlation,
124-
and per-request concurrency live in the dispatcher; this class owns the
125-
MCP type layer: typed requests, the initialize handshake, and routing
126-
server-initiated traffic to the constructor callbacks.
120+
"""Client half of an MCP connection, running on a `Dispatcher`.
121+
122+
Construct it over a transport's stream pair (or pass a pre-built
123+
`dispatcher=` instead, e.g. a `DirectDispatcher` for in-process
124+
embedding), enter it as an async context manager, then call
125+
`initialize()`. The receive loop, request correlation, and per-request
126+
concurrency live in the dispatcher; this class owns the MCP type layer:
127+
typed requests, the initialize handshake, and routing server-initiated
128+
traffic to the constructor callbacks.
129+
130+
Transport-level `Exception` items reach `message_handler` only when the
131+
session builds its own dispatcher from streams, where it wires the
132+
dispatcher's `on_stream_exception` itself.
127133
"""
128134

129135
def __init__(
130136
self,
131-
read_stream: ReadStream[SessionMessage | Exception],
132-
write_stream: WriteStream[SessionMessage],
137+
read_stream: ReadStream[SessionMessage | Exception] | None = None,
138+
write_stream: WriteStream[SessionMessage] | None = None,
133139
read_timeout_seconds: float | None = None,
134140
sampling_callback: SamplingFnT | None = None,
135141
elicitation_callback: ElicitationFnT | None = None,
@@ -139,69 +145,7 @@ def __init__(
139145
client_info: types.Implementation | None = None,
140146
*,
141147
sampling_capabilities: types.SamplingCapability | None = None,
142-
) -> None:
143-
self._init_state(
144-
read_timeout_seconds=read_timeout_seconds,
145-
sampling_callback=sampling_callback,
146-
elicitation_callback=elicitation_callback,
147-
list_roots_callback=list_roots_callback,
148-
logging_callback=logging_callback,
149-
message_handler=message_handler,
150-
client_info=client_info,
151-
sampling_capabilities=sampling_capabilities,
152-
)
153-
# Built here (inert until run() starts in __aenter__) so notifications
154-
# can be sent before entering the context manager, as before.
155-
self._dispatcher: Dispatcher[Any] = JSONRPCDispatcher(
156-
read_stream, write_stream, on_stream_exception=self._on_stream_exception
157-
)
158-
159-
@classmethod
160-
def from_dispatcher(
161-
cls,
162-
dispatcher: Dispatcher[Any],
163-
*,
164-
read_timeout_seconds: float | None = None,
165-
sampling_callback: SamplingFnT | None = None,
166-
elicitation_callback: ElicitationFnT | None = None,
167-
list_roots_callback: ListRootsFnT | None = None,
168-
logging_callback: LoggingFnT | None = None,
169-
message_handler: MessageHandlerFnT | None = None,
170-
client_info: types.Implementation | None = None,
171-
sampling_capabilities: types.SamplingCapability | None = None,
172-
) -> Self:
173-
"""Build a session over a pre-built dispatcher instead of a stream pair.
174-
175-
For embedding a server in-process (`DirectDispatcher`) or transports
176-
that construct their own dispatcher. Transport-level `Exception` items
177-
reach `message_handler` only on the stream constructor, where the
178-
session wires the dispatcher's `on_stream_exception` itself.
179-
"""
180-
self = cls.__new__(cls)
181-
self._init_state(
182-
read_timeout_seconds=read_timeout_seconds,
183-
sampling_callback=sampling_callback,
184-
elicitation_callback=elicitation_callback,
185-
list_roots_callback=list_roots_callback,
186-
logging_callback=logging_callback,
187-
message_handler=message_handler,
188-
client_info=client_info,
189-
sampling_capabilities=sampling_capabilities,
190-
)
191-
self._dispatcher = dispatcher
192-
return self
193-
194-
def _init_state(
195-
self,
196-
*,
197-
read_timeout_seconds: float | None,
198-
sampling_callback: SamplingFnT | None,
199-
elicitation_callback: ElicitationFnT | None,
200-
list_roots_callback: ListRootsFnT | None,
201-
logging_callback: LoggingFnT | None,
202-
message_handler: MessageHandlerFnT | None,
203-
client_info: types.Implementation | None,
204-
sampling_capabilities: types.SamplingCapability | None,
148+
dispatcher: Dispatcher[Any] | None = None,
205149
) -> None:
206150
self._session_read_timeout_seconds = read_timeout_seconds
207151
self._client_info = client_info or DEFAULT_CLIENT_INFO
@@ -214,6 +158,18 @@ def _init_state(
214158
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
215159
self._initialize_result: types.InitializeResult | None = None
216160
self._task_group: anyio.abc.TaskGroup | None = None
161+
if dispatcher is not None:
162+
if read_stream is not None or write_stream is not None:
163+
raise ValueError("pass read_stream/write_stream or dispatcher, not both")
164+
self._dispatcher: Dispatcher[Any] = dispatcher
165+
else:
166+
if read_stream is None or write_stream is None:
167+
raise ValueError("read_stream and write_stream are required when no dispatcher is given")
168+
# Built here (inert until run() starts in __aenter__) so notifications
169+
# can be sent before entering the context manager, as before.
170+
self._dispatcher = JSONRPCDispatcher(
171+
read_stream, write_stream, on_stream_exception=self._on_stream_exception
172+
)
217173

218174
async def __aenter__(self) -> Self:
219175
self._task_group = anyio.create_task_group()

tests/client/test_session.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -900,8 +900,8 @@ async def call() -> None:
900900

901901

902902
@pytest.mark.anyio
903-
async def test_from_dispatcher_runs_over_direct_dispatch():
904-
"""A session built with from_dispatcher works without a stream pair (in-process embedding)."""
903+
async def test_dispatcher_keyword_runs_over_direct_dispatch():
904+
"""A session built with dispatcher= works without a stream pair (in-process embedding)."""
905905
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
906906
from mcp.shared.dispatcher import DispatchContext
907907
from mcp.shared.transport_context import TransportContext
@@ -921,7 +921,7 @@ async def server_on_notify(
921921
) -> None:
922922
notified.append(method)
923923

924-
session = ClientSession.from_dispatcher(client_side)
924+
session = ClientSession(dispatcher=client_side)
925925
results: list[types.EmptyResult] = []
926926
async with anyio.create_task_group() as tg:
927927
await tg.start(server_side.run, server_on_request, server_on_notify)
@@ -938,6 +938,27 @@ async def server_on_notify(
938938
assert notified == ["notifications/roots/list_changed"]
939939

940940

941+
def test_constructor_rejects_streams_and_dispatcher_together():
942+
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
943+
944+
client_side, _server_side = create_direct_dispatcher_pair()
945+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1)
946+
with pytest.raises(ValueError, match="not both"):
947+
ClientSession(s2c_recv, dispatcher=client_side)
948+
s2c_send.close()
949+
s2c_recv.close()
950+
951+
952+
def test_constructor_requires_both_streams_without_dispatcher():
953+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1)
954+
with pytest.raises(ValueError, match="read_stream and write_stream are required"):
955+
ClientSession(s2c_recv)
956+
with pytest.raises(ValueError, match="read_stream and write_stream are required"):
957+
ClientSession()
958+
s2c_send.close()
959+
s2c_recv.close()
960+
961+
941962
@pytest.mark.anyio
942963
async def test_send_request_with_server_metadata_routes_related_request_id():
943964
"""ServerMessageMetadata.related_request_id is threaded onto the outgoing message."""

0 commit comments

Comments
 (0)