From 9dddd4e89bcd0a8b49fe2264cad902ff7c538a8f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:56:49 +0000 Subject: [PATCH 01/27] feat: add Dispatcher Protocol and DirectDispatcher Introduces the Dispatcher abstraction that decouples MCP request/response handling from JSON-RPC framing. A Dispatcher exposes call/notify for outbound messages and run(on_call, on_notify) for inbound dispatch, with no knowledge of MCP types or wire encoding. - shared/dispatcher.py: Dispatcher, DispatchContext, RequestSender Protocols; CallOptions, OnCall/OnNotify, ProgressFnT, DispatchMiddleware - shared/transport_context.py: TransportContext base dataclass - shared/direct_dispatcher.py: in-memory Dispatcher impl that wires two peers with no transport; serves as a fast test substrate and second-impl proof - shared/exceptions.py: NoBackChannelError(MCPError) for transports without a server-to-client request channel - types: REQUEST_CANCELLED SDK error code The JSON-RPC implementation and ServerRunner that consume this Protocol land in follow-up PRs. --- src/mcp/shared/direct_dispatcher.py | 173 +++++++++++++++++++ src/mcp/shared/dispatcher.py | 167 ++++++++++++++++++ src/mcp/shared/exceptions.py | 21 ++- src/mcp/shared/transport_context.py | 30 ++++ src/mcp/types/__init__.py | 2 + src/mcp/types/jsonrpc.py | 1 + tests/shared/test_dispatcher.py | 253 ++++++++++++++++++++++++++++ 7 files changed, 646 insertions(+), 1 deletion(-) create mode 100644 src/mcp/shared/direct_dispatcher.py create mode 100644 src/mcp/shared/dispatcher.py create mode 100644 src/mcp/shared/transport_context.py create mode 100644 tests/shared/test_dispatcher.py diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py new file mode 100644 index 0000000000..4650619428 --- /dev/null +++ b/src/mcp/shared/direct_dispatcher.py @@ -0,0 +1,173 @@ +"""In-memory `Dispatcher` that wires two peers together with no transport. + +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a call +on one side directly invokes the other side's `on_call`. There is no +serialization, no JSON-RPC framing, and no streams. It exists to: + +* prove the `Dispatcher` Protocol is implementable without JSON-RPC +* provide a fast substrate for testing the layers above the dispatcher + (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts +* embed a server in-process when the JSON-RPC overhead is unnecessary + +Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly +to the caller — there is no exception-to-`ErrorData` boundary here. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + +import anyio + +from mcp.shared.dispatcher import CallOptions, OnCall, OnNotify, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT + +__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] + +DIRECT_TRANSPORT_KIND = "direct" + + +_Call = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] + + +@dataclass +class _DirectDispatchContext: + """`DispatchContext` for an inbound call on a `DirectDispatcher`. + + The back-channel callables target the *originating* side, so a handler's + `send_request` reaches the peer that made the inbound call. + """ + + transport: TransportContext + _back_call: _Call + _back_notify: _Notify + _on_progress: ProgressFnT | None = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._back_notify(method, params) + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.transport.can_send_request: + raise NoBackChannelError(method) + return await self._back_call(method, params, opts) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._on_progress is not None: + await self._on_progress(progress, total, message) + + +class DirectDispatcher: + """A `Dispatcher` that calls a peer's handlers directly, in-process. + + Two instances are wired together with `create_direct_dispatcher_pair`; each + holds a reference to the other. `call` on one awaits the peer's `on_call`. + `run` parks until `close` is called. + """ + + def __init__(self, transport_ctx: TransportContext): + self._transport_ctx = transport_ctx + self._peer: DirectDispatcher | None = None + self._on_call: OnCall | None = None + self._on_notify: OnNotify | None = None + self._ready = anyio.Event() + self._closed = anyio.Event() + + def connect_to(self, peer: DirectDispatcher) -> None: + self._peer = peer + + async def call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + return await self._peer._dispatch_call(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + await self._peer._dispatch_notify(method, params) + + async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + self._on_call = on_call + self._on_notify = on_notify + self._ready.set() + await self._closed.wait() + + def close(self) -> None: + self._closed.set() + + def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + assert self._peer is not None + peer = self._peer + return _DirectDispatchContext( + transport=self._transport_ctx, + _back_call=lambda m, p, o: peer._dispatch_call(m, p, o), + _back_notify=lambda m, p: peer._dispatch_notify(m, p), + _on_progress=on_progress, + ) + + async def _dispatch_call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None, + ) -> dict[str, Any]: + await self._ready.wait() + assert self._on_call is not None + opts = opts or {} + dctx = self._make_context(on_progress=opts.get("on_progress")) + try: + with anyio.fail_after(opts.get("timeout")): + try: + return await self._on_call(dctx, method, params) + except MCPError: + raise + except Exception as e: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + except TimeoutError: + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", + ) from None + + async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._ready.wait() + assert self._on_notify is not None + dctx = self._make_context() + await self._on_notify(dctx, method, params) + + +def create_direct_dispatcher_pair( + *, + can_send_request: bool = True, +) -> tuple[DirectDispatcher, DirectDispatcher]: + """Create two `DirectDispatcher` instances wired to each other. + + Args: + can_send_request: Sets `TransportContext.can_send_request` on both + sides. Pass ``False`` to simulate a transport with no back-channel. + + Returns: + A ``(left, right)`` pair. Conventionally ``left`` is the client side + and ``right`` is the server side, but the wiring is symmetric. + """ + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + left = DirectDispatcher(ctx) + right = DirectDispatcher(ctx) + left.connect_to(right) + right.connect_to(left) + return left, right diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 0000000000..09e5e87bb6 --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,167 @@ +"""Dispatcher Protocol — the call/return boundary between transports and handlers. + +A Dispatcher turns a duplex message channel into two things: + +* an outbound API: ``call(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_call, on_notify)`` that drives the receive loop and + invokes the supplied handlers for each incoming request/notification + +It is deliberately *not* MCP-aware. Method names are strings, params and +results are ``dict[str, Any]``. The MCP type layer (request/result models, +capability negotiation, ``Context``) sits above this; the wire encoding +(JSON-RPC, gRPC, in-process direct calls) sits below it. + +See ``JSONRPCDispatcher`` for the production implementation and +``DirectDispatcher`` for an in-memory implementation used in tests and for +embedding a server in-process. +""" + +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable + +import anyio + +from mcp.shared.transport_context import TransportContext + +__all__ = [ + "CallOptions", + "DispatchContext", + "DispatchMiddleware", + "Dispatcher", + "OnCall", + "OnNotify", + "ProgressFnT", + "RequestSender", +] + +TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) + + +class ProgressFnT(Protocol): + """Callback invoked when a progress notification arrives for a pending call.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class CallOptions(TypedDict, total=False): + """Per-call options for `RequestSender.send_request` / `Dispatcher.call`. + + All keys are optional. Dispatchers ignore keys they do not understand. + """ + + timeout: float + """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" + + on_progress: ProgressFnT + """Receive ``notifications/progress`` updates for this call.""" + + resumption_token: str + """Opaque token to resume a previously interrupted call (transport-dependent).""" + + on_resumption_token: Callable[[str], Awaitable[None]] + """Receive a resumption token when the transport issues one.""" + + +@runtime_checkable +class RequestSender(Protocol): + """Anything that can send a request and await its result. + + Both `Dispatcher` (for top-level outbound calls) and `DispatchContext` + (for server-to-client calls made *during* an inbound request) satisfy this. + """ + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: ... + + +class DispatchContext(Protocol[TransportT_co]): + """Per-request context handed to ``on_call`` / ``on_notify``. + + Carries the transport metadata for the inbound message and provides the + back-channel for sending requests/notifications to the peer while handling + it. + """ + + @property + def transport(self) -> TransportT_co: + """Transport-specific metadata for this inbound message.""" + ... + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer.""" + ... + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel and await its result. + + Raises: + NoBackChannelError: if ``transport.can_send_request`` is ``False``. + """ + ... + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + ... + + +OnCall = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" + +OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] +"""Handler for inbound notifications: ``(ctx, method, params)``.""" + +DispatchMiddleware = Callable[[OnCall], OnCall] +"""Wraps an ``OnCall`` to produce another ``OnCall``. Applied outermost-first.""" + + +class Dispatcher(Protocol[TransportT_co]): + """A duplex request/notification channel with call-return semantics. + + Implementations own correlation of outbound calls to inbound results, the + receive loop, per-request concurrency, and cancellation/progress wiring. + """ + + async def call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request and await its result. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... + + async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + """Drive the receive loop until the underlying channel closes. + + Each inbound request is dispatched to ``on_call`` in its own task; the + returned dict (or raised ``MCPError``) is sent back as the response. + Inbound notifications go to ``on_notify``. + """ + ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319d..e9dd2c843e 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -2,7 +2,7 @@ from typing import Any, cast -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError class MCPError(Exception): @@ -41,6 +41,25 @@ def __str__(self) -> str: return self.message +class NoBackChannelError(MCPError): + """Raised when sending a server-initiated request over a transport that cannot deliver it. + + Stateless HTTP and JSON-response-mode HTTP have no channel for the server to + push requests (sampling, elicitation, roots/list) to the client. This is + raised by `DispatchContext.send_request` when `transport.can_send_request` + is ``False``, and serializes to an ``INVALID_REQUEST`` error response. + """ + + def __init__(self, method: str): + super().__init__( + code=INVALID_REQUEST, + message=( + f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests." + ), + ) + self.method = method + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py new file mode 100644 index 0000000000..31230fda92 --- /dev/null +++ b/src/mcp/shared/transport_context.py @@ -0,0 +1,30 @@ +"""Transport-specific metadata attached to each inbound message. + +`TransportContext` is the base; each transport defines its own subclass with +whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, +etc.). The dispatcher passes it through opaquely; only the layers above the +dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +""" + +from dataclasses import dataclass + +__all__ = ["TransportContext"] + + +@dataclass(kw_only=True, frozen=True) +class TransportContext: + """Base transport metadata for an inbound message. + + Subclass per transport and add fields as needed. Instances are immutable. + """ + + kind: str + """Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``).""" + + can_send_request: bool + """Whether the transport can deliver server-initiated requests to the peer. + + ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for + stdio, SSE, and stateful streamable HTTP. When ``False``, + `DispatchContext.send_request` raises `NoBackChannelError`. + """ diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..ca1c328939 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -192,6 +192,7 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, + REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -401,6 +402,7 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", + "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 84304a37c1..14743c33b0 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,6 +43,7 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 +REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 0000000000..dd8d40721a --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,253 @@ +"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. + +These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using +the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in +``test_jsonrpc_dispatcher.py``. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnCall, OnNotify +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT + + +class Recorder: + def __init__(self) -> None: + self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.contexts: list[DispatchContext[TransportContext]] = [] + self.notified = anyio.Event() + + +def echo_handlers(recorder: Recorder) -> tuple[OnCall, OnNotify]: + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + recorder.calls.append((method, params)) + recorder.contexts.append(ctx) + return {"echoed": method, "params": dict(params or {})} + + async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + recorder.notifications.append((method, params)) + recorder.notified.set() + + return on_call, on_notify + + +@asynccontextmanager +async def running_pair( + *, + server_on_call: OnCall | None = None, + server_on_notify: OnNotify | None = None, + client_on_call: OnCall | None = None, + client_on_notify: OnNotify | None = None, + can_send_request: bool = True, +) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: + """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client_rec, server_rec = Recorder(), Recorder() + c_call, c_notify = echo_handlers(client_rec) + s_call, s_notify = echo_handlers(server_rec) + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, client_on_call or c_call, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_call or s_call, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + client.close() + server.close() + + +@pytest.mark.anyio +async def test_call_returns_result_from_peer_on_call(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + result = await client.call("tools/list", {"cursor": "abc"}) + assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} + assert srec.calls == [("tools/list", {"cursor": "abc"})] + + +@pytest.mark.anyio +async def test_call_reraises_mcperror_from_handler_unchanged(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise MCPError(code=INVALID_PARAMS, message="bad cursor") + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("tools/list", {}) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" + + +@pytest.mark.anyio +async def test_call_wraps_non_mcperror_exception_as_internal_error(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_call_with_timeout_raises_mcperror_request_timeout(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await anyio.sleep_forever() + return {} + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("slow", None, {"timeout": 0}) + assert exc.value.error.code == REQUEST_TIMEOUT + + +@pytest.mark.anyio +async def test_notify_invokes_peer_on_notify(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/initialized", {"v": 1}) + await srec.notified.wait() + assert srec.notifications == [("notifications/initialized", {"v": 1})] + + +@pytest.mark.anyio +async def test_ctx_send_request_round_trips_to_calling_side(): + """A handler's ctx.send_request reaches the side that made the inbound call.""" + + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + return {"sampled": sample} + + async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + with anyio.fail_after(5): + result = await client.call("tools/call", None) + assert crec.calls == [("sampling/createMessage", {"prompt": "hi"})] + assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} + + +@pytest.mark.anyio +async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.send_request("sampling/createMessage", None) + return {} + + async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + await client.call("tools/call", None) + assert exc.value.method == "sampling/createMessage" + assert exc.value.error.code == INVALID_REQUEST + + +@pytest.mark.anyio +async def test_ctx_notify_invokes_calling_side_on_notify(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/message", {"level": "info"}) + return {} + + async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.call("tools/call", None) + await crec.notified.wait() + assert crec.notifications == [("notifications/message", {"level": "info"})] + + +@pytest.mark.anyio +async def test_ctx_progress_invokes_caller_on_progress_callback(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with running_pair(server_on_call=server_on_call) as (client, *_): + with anyio.fail_after(5): + await client.call("tools/call", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_call_issued_before_peer_run_blocks_until_peer_ready(): + client, server = create_direct_dispatcher_pair() + s_call, s_notify = echo_handlers(Recorder()) + c_call, c_notify = echo_handlers(Recorder()) + + async def late_start(): + await anyio.sleep(0) + await server.run(s_call, s_notify) + + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, c_call, c_notify) + tg.start_soon(late_start) + with anyio.fail_after(5): + result = await client.call("ping", None) + assert result == {"echoed": "ping", "params": {}} + client.close() + server.close() + + +@pytest.mark.anyio +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(server_on_call=server_on_call) as (client, *_): + with anyio.fail_after(5): + result = await client.call("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): + d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + with pytest.raises(RuntimeError, match="no peer"): + await d.call("ping", None) + with pytest.raises(RuntimeError, match="no peer"): + await d.notify("ping", None) + + +@pytest.mark.anyio +async def test_close_makes_run_return(): + client, server = create_direct_dispatcher_pair() + on_call, on_notify = echo_handlers(Recorder()) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, on_call, on_notify) + tg.start_soon(client.run, on_call, on_notify) + client.close() + server.close() + + +if TYPE_CHECKING: + _dispatcher_check: Dispatcher[TransportContext] = DirectDispatcher( + TransportContext(kind="direct", can_send_request=True) + ) From 0b98454448919046fb3045e7611d3b530a50c73c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:03:48 +0000 Subject: [PATCH 02/27] fix: address coverage gaps and stale RequestSender docstring - tests: replace unreachable 'return {}' with 'raise NotImplementedError' (already in coverage exclude_also) and collapse send_request+return into one statement - dispatcher: RequestSender docstring no longer claims Dispatcher satisfies it (Dispatcher exposes call(), not send_request()) --- src/mcp/shared/dispatcher.py | 4 ++-- tests/shared/test_dispatcher.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 09e5e87bb6..b63c00c0bf 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -66,8 +66,8 @@ class CallOptions(TypedDict, total=False): class RequestSender(Protocol): """Anything that can send a request and await its result. - Both `Dispatcher` (for top-level outbound calls) and `DispatchContext` - (for server-to-client calls made *during* an inbound request) satisfy this. + `DispatchContext` satisfies this; `PeerMixin` (and `Connection`/`Peer`) wrap + a `RequestSender` to provide typed request methods. """ async def send_request( diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index dd8d40721a..ddfe1f798f 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -109,7 +109,7 @@ async def on_call( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() - return {} + raise NotImplementedError async with running_pair(server_on_call=on_call) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: @@ -148,8 +148,7 @@ async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallo async def server_on_call( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - await ctx.send_request("sampling/createMessage", None) - return {} + return await ctx.send_request("sampling/createMessage", None) async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: From 540d8162c74eabe5830c80d032a21e703612fb0e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:52:58 +0000 Subject: [PATCH 03/27] refactor: rename Dispatcher.call to send_request, replace RequestSender with Outbound MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The design doc's `send_request = call` alias only makes the concrete class satisfy RequestSender, not the abstract Dispatcher Protocol — so any consumer typed against `Dispatcher[TT]` (Connection, ServerRunner) couldn't pass it to something expecting a RequestSender without a cast or hand-written bridge. RequestSender was also half a contract: every implementor (Dispatcher, DispatchContext, Connection, Context) has `notify` too, and PeerMixin needs both for its typed sugar (elicit/sample are requests, log is a notification). Outbound(Protocol) declares both methods; Dispatcher and DispatchContext extend it. PeerMixin will wrap an Outbound. One verb everywhere, no aliases, no extra Protocols. - Dispatcher.call -> send_request - OnCall -> OnRequest, on_call -> on_request - RequestSender -> Outbound (now also declares notify) - Dispatcher(Outbound, Protocol[TT]), DispatchContext(Outbound, Protocol[TT]) --- src/mcp/shared/direct_dispatcher.py | 38 ++++----- src/mcp/shared/dispatcher.py | 100 ++++++++++-------------- tests/shared/test_dispatcher.py | 115 ++++++++++++++-------------- 3 files changed, 115 insertions(+), 138 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 4650619428..79b68d0547 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -1,7 +1,7 @@ """In-memory `Dispatcher` that wires two peers together with no transport. -`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a call -on one side directly invokes the other side's `on_call`. There is no +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a +request on one side directly invokes the other side's `on_request`. There is no serialization, no JSON-RPC framing, and no streams. It exists to: * prove the `Dispatcher` Protocol is implementable without JSON-RPC @@ -21,7 +21,7 @@ import anyio -from mcp.shared.dispatcher import CallOptions, OnCall, OnNotify, ProgressFnT +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT @@ -31,20 +31,20 @@ DIRECT_TRANSPORT_KIND = "direct" -_Call = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] _Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] @dataclass class _DirectDispatchContext: - """`DispatchContext` for an inbound call on a `DirectDispatcher`. + """`DispatchContext` for an inbound request on a `DirectDispatcher`. The back-channel callables target the *originating* side, so a handler's - `send_request` reaches the peer that made the inbound call. + `send_request` reaches the peer that made the inbound request. """ transport: TransportContext - _back_call: _Call + _back_request: _Request _back_notify: _Notify _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -60,7 +60,7 @@ async def send_request( ) -> dict[str, Any]: if not self.transport.can_send_request: raise NoBackChannelError(method) - return await self._back_call(method, params, opts) + return await self._back_request(method, params, opts) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: if self._on_progress is not None: @@ -71,14 +71,14 @@ class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `call` on one awaits the peer's `on_call`. - `run` parks until `close` is called. + holds a reference to the other. `send_request` on one awaits the peer's + `on_request`. `run` parks until `close` is called. """ def __init__(self, transport_ctx: TransportContext): self._transport_ctx = transport_ctx self._peer: DirectDispatcher | None = None - self._on_call: OnCall | None = None + self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None self._ready = anyio.Event() self._closed = anyio.Event() @@ -86,7 +86,7 @@ def __init__(self, transport_ctx: TransportContext): def connect_to(self, peer: DirectDispatcher) -> None: self._peer = peer - async def call( + async def send_request( self, method: str, params: Mapping[str, Any] | None, @@ -94,15 +94,15 @@ async def call( ) -> dict[str, Any]: if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") - return await self._peer._dispatch_call(method, params, opts) + return await self._peer._dispatch_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: - self._on_call = on_call + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + self._on_request = on_request self._on_notify = on_notify self._ready.set() await self._closed.wait() @@ -115,25 +115,25 @@ def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispat peer = self._peer return _DirectDispatchContext( transport=self._transport_ctx, - _back_call=lambda m, p, o: peer._dispatch_call(m, p, o), + _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), _back_notify=lambda m, p: peer._dispatch_notify(m, p), _on_progress=on_progress, ) - async def _dispatch_call( + async def _dispatch_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None, ) -> dict[str, Any]: await self._ready.wait() - assert self._on_call is not None + assert self._on_request is not None opts = opts or {} dctx = self._make_context(on_progress=opts.get("on_progress")) try: with anyio.fail_after(opts.get("timeout")): try: - return await self._on_call(dctx, method, params) + return await self._on_request(dctx, method, params) except MCPError: raise except Exception as e: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index b63c00c0bf..872fb01eaa 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -2,9 +2,9 @@ A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``call(method, params)`` and ``notify(method, params)`` -* an inbound pump: ``run(on_call, on_notify)`` that drives the receive loop and - invokes the supplied handlers for each incoming request/notification +* an outbound API: ``send_request(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop + and invokes the supplied handlers for each incoming request/notification It is deliberately *not* MCP-aware. Method names are strings, params and results are ``dict[str, Any]``. The MCP type layer (request/result models, @@ -28,23 +28,23 @@ "DispatchContext", "DispatchMiddleware", "Dispatcher", - "OnCall", "OnNotify", + "OnRequest", + "Outbound", "ProgressFnT", - "RequestSender", ] TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) class ProgressFnT(Protocol): - """Callback invoked when a progress notification arrives for a pending call.""" + """Callback invoked when a progress notification arrives for a pending request.""" async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... class CallOptions(TypedDict, total=False): - """Per-call options for `RequestSender.send_request` / `Dispatcher.call`. + """Per-call options for `Outbound.send_request`. All keys are optional. Dispatchers ignore keys they do not understand. """ @@ -53,21 +53,22 @@ class CallOptions(TypedDict, total=False): """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" on_progress: ProgressFnT - """Receive ``notifications/progress`` updates for this call.""" + """Receive ``notifications/progress`` updates for this request.""" resumption_token: str - """Opaque token to resume a previously interrupted call (transport-dependent).""" + """Opaque token to resume a previously interrupted request (transport-dependent).""" on_resumption_token: Callable[[str], Awaitable[None]] """Receive a resumption token when the transport issues one.""" @runtime_checkable -class RequestSender(Protocol): - """Anything that can send a request and await its result. +class Outbound(Protocol): + """Anything that can send requests and notifications to the peer. - `DispatchContext` satisfies this; `PeerMixin` (and `Connection`/`Peer`) wrap - a `RequestSender` to provide typed request methods. + Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel + during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to + provide typed MCP request/notification methods. """ async def send_request( @@ -75,15 +76,28 @@ async def send_request( method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, - ) -> dict[str, Any]: ... + ) -> dict[str, Any]: + """Send a request and await its result. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... -class DispatchContext(Protocol[TransportT_co]): - """Per-request context handed to ``on_call`` / ``on_notify``. + +class DispatchContext(Outbound, Protocol[TransportT_co]): + """Per-request context handed to ``on_request`` / ``on_notify``. Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. + it. `send_request` raises `NoBackChannelError` if + ``transport.can_send_request`` is ``False``. """ @property @@ -96,23 +110,6 @@ def cancel_requested(self) -> anyio.Event: """Set when the peer sends ``notifications/cancelled`` for this request.""" ... - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - """Send a notification to the peer.""" - ... - - async def send_request( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - ) -> dict[str, Any]: - """Send a request to the peer on the back-channel and await its result. - - Raises: - NoBackChannelError: if ``transport.can_send_request`` is ``False``. - """ - ... - async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for the inbound request, if the peer supplied a progress token. @@ -121,47 +118,28 @@ async def progress(self, progress: float, total: float | None = None, message: s ... -OnCall = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] """Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] """Handler for inbound notifications: ``(ctx, method, params)``.""" -DispatchMiddleware = Callable[[OnCall], OnCall] -"""Wraps an ``OnCall`` to produce another ``OnCall``. Applied outermost-first.""" +DispatchMiddleware = Callable[[OnRequest], OnRequest] +"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first.""" -class Dispatcher(Protocol[TransportT_co]): +class Dispatcher(Outbound, Protocol[TransportT_co]): """A duplex request/notification channel with call-return semantics. - Implementations own correlation of outbound calls to inbound results, the + Implementations own correlation of outbound requests to inbound results, the receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def call( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - ) -> dict[str, Any]: - """Send a request and await its result. - - Raises: - MCPError: If the peer responded with an error, or the handler - raised. Implementations normalize all handler exceptions to - `MCPError` so callers see a single exception type. - """ - ... - - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - """Send a fire-and-forget notification.""" - ... - - async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: """Drive the receive loop until the underlying channel closes. - Each inbound request is dispatched to ``on_call`` in its own task; the - returned dict (or raised ``MCPError``) is sent back as the response. + Each inbound request is dispatched to ``on_request`` in its own task; + the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. """ ... diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index ddfe1f798f..44ab622ad6 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -13,7 +13,7 @@ import pytest from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair -from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnCall, OnNotify +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT @@ -21,17 +21,17 @@ class Recorder: def __init__(self) -> None: - self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] self.contexts: list[DispatchContext[TransportContext]] = [] self.notified = anyio.Event() -def echo_handlers(recorder: Recorder) -> tuple[OnCall, OnNotify]: - async def on_call( +def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - recorder.calls.append((method, params)) + recorder.requests.append((method, params)) recorder.contexts.append(ctx) return {"echoed": method, "params": dict(params or {})} @@ -39,26 +39,26 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: recorder.notifications.append((method, params)) recorder.notified.set() - return on_call, on_notify + return on_request, on_notify @asynccontextmanager async def running_pair( *, - server_on_call: OnCall | None = None, + server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, - client_on_call: OnCall | None = None, + client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, ) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() - c_call, c_notify = echo_handlers(client_rec) - s_call, s_notify = echo_handlers(server_rec) + c_req, c_notify = echo_handlers(client_rec) + s_req, s_notify = echo_handlers(server_rec) async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_call or c_call, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_call or s_call, server_on_notify or s_notify) + tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) try: yield client, server, client_rec, server_rec finally: @@ -67,53 +67,53 @@ async def running_pair( @pytest.mark.anyio -async def test_call_returns_result_from_peer_on_call(): +async def test_send_request_returns_result_from_peer_on_request(): async with running_pair() as (client, _server, _crec, srec): with anyio.fail_after(5): - result = await client.call("tools/list", {"cursor": "abc"}) + result = await client.send_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} - assert srec.calls == [("tools/list", {"cursor": "abc"})] + assert srec.requests == [("tools/list", {"cursor": "abc"})] @pytest.mark.anyio -async def test_call_reraises_mcperror_from_handler_unchanged(): - async def on_call( +async def test_send_request_reraises_mcperror_from_handler_unchanged(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("tools/list", {}) + await client.send_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "bad cursor" @pytest.mark.anyio -async def test_call_wraps_non_mcperror_exception_as_internal_error(): - async def on_call( +async def test_send_request_wraps_non_mcperror_exception_as_internal_error(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise ValueError("oops") - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("tools/list", {}) + await client.send_request("tools/list", {}) assert exc.value.error.code == INTERNAL_ERROR assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio -async def test_call_with_timeout_raises_mcperror_request_timeout(): - async def on_call( +async def test_send_request_with_timeout_raises_mcperror_request_timeout(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("slow", None, {"timeout": 0}) + await client.send_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @@ -128,53 +128,53 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio async def test_ctx_send_request_round_trips_to_calling_side(): - """A handler's ctx.send_request reaches the side that made the inbound call.""" + """A handler's ctx.send_request reaches the side that made the inbound request.""" - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - result = await client.call("tools/call", None) - assert crec.calls == [("sampling/createMessage", {"prompt": "hi"})] + result = await client.send_request("tools/call", None) + assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} @pytest.mark.anyio async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_request("sampling/createMessage", None) - async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): + async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: - await client.call("tools/call", None) + await client.send_request("tools/call", None) assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio async def test_ctx_notify_invokes_calling_side_on_notify(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - await client.call("tools/call", None) + await client.send_request("tools/call", None) await crec.notified.wait() assert crec.notifications == [("notifications/message", {"level": "info"})] @pytest.mark.anyio async def test_ctx_progress_invokes_caller_on_progress_callback(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.progress(0.5, total=1.0, message="halfway") @@ -185,27 +185,27 @@ async def server_on_call( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_call=server_on_call) as (client, *_): + async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.call("tools/call", None, {"on_progress": on_progress}) + await client.send_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_call_issued_before_peer_run_blocks_until_peer_ready(): +async def test_send_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() - s_call, s_notify = echo_handlers(Recorder()) - c_call, c_notify = echo_handlers(Recorder()) + s_req, s_notify = echo_handlers(Recorder()) + c_req, c_notify = echo_handlers(Recorder()) async def late_start(): await anyio.sleep(0) - await server.run(s_call, s_notify) + await server.run(s_req, s_notify) async with anyio.create_task_group() as tg: - tg.start_soon(client.run, c_call, c_notify) + tg.start_soon(client.run, c_req, c_notify) tg.start_soon(late_start) with anyio.fail_after(5): - result = await client.call("ping", None) + result = await client.send_request("ping", None) assert result == {"echoed": "ping", "params": {}} client.close() server.close() @@ -213,23 +213,23 @@ async def late_start(): @pytest.mark.anyio async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.progress(0.5) return {"ok": True} - async with running_pair(server_on_call=server_on_call) as (client, *_): + async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - result = await client.call("tools/call", None) + result = await client.send_request("tools/call", None) assert result == {"ok": True} @pytest.mark.anyio -async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_send_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): - await d.call("ping", None) + await d.send_request("ping", None) with pytest.raises(RuntimeError, match="no peer"): await d.notify("ping", None) @@ -237,16 +237,15 @@ async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): @pytest.mark.anyio async def test_close_makes_run_return(): client, server = create_direct_dispatcher_pair() - on_call, on_notify = echo_handlers(Recorder()) + on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): async with anyio.create_task_group() as tg: - tg.start_soon(server.run, on_call, on_notify) - tg.start_soon(client.run, on_call, on_notify) + tg.start_soon(server.run, on_request, on_notify) + tg.start_soon(client.run, on_request, on_notify) client.close() server.close() if TYPE_CHECKING: - _dispatcher_check: Dispatcher[TransportContext] = DirectDispatcher( - TransportContext(kind="direct", can_send_request=True) - ) + _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + _o: Outbound = _d From ebe9a9939493ebf59e0b486b486c7d5a9242f5df Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 21:38:32 +0000 Subject: [PATCH 04/27] refactor: rename Outbound.send_request to send_raw_request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The dispatcher-layer raw channel is now `send_raw_request(method, params) -> dict`. This frees the `send_request` name for the typed surface (`send_request(req: Request) -> Result`) that Connection/Context/Client add in later PRs. Mechanical rename across Outbound, Dispatcher, DispatchContext, DirectDispatcher, _DirectDispatchContext, and all tests. `can_send_request` (the transport capability flag) is unchanged — it names the capability, not the method. --- src/mcp/shared/direct_dispatcher.py | 8 +++--- src/mcp/shared/dispatcher.py | 15 +++++----- src/mcp/shared/exceptions.py | 2 +- src/mcp/shared/transport_context.py | 2 +- tests/shared/test_dispatcher.py | 44 ++++++++++++++--------------- 5 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 79b68d0547..bb5639a136 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -40,7 +40,7 @@ class _DirectDispatchContext: """`DispatchContext` for an inbound request on a `DirectDispatcher`. The back-channel callables target the *originating* side, so a handler's - `send_request` reaches the peer that made the inbound request. + `send_raw_request` reaches the peer that made the inbound request. """ transport: TransportContext @@ -52,7 +52,7 @@ class _DirectDispatchContext: async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._back_notify(method, params) - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -71,7 +71,7 @@ class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `send_request` on one awaits the peer's + holds a reference to the other. `send_raw_request` on one awaits the peer's `on_request`. `run` parks until `close` is called. """ @@ -86,7 +86,7 @@ def __init__(self, transport_ctx: TransportContext): def connect_to(self, peer: DirectDispatcher) -> None: self._peer = peer - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 872fb01eaa..ee02e23896 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -2,7 +2,7 @@ A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``send_request(method, params)`` and ``notify(method, params)`` +* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)`` * an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop and invokes the supplied handlers for each incoming request/notification @@ -44,7 +44,7 @@ async def __call__(self, progress: float, total: float | None, message: str | No class CallOptions(TypedDict, total=False): - """Per-call options for `Outbound.send_request`. + """Per-call options for `Outbound.send_raw_request`. All keys are optional. Dispatchers ignore keys they do not understand. """ @@ -67,17 +67,18 @@ class Outbound(Protocol): """Anything that can send requests and notifications to the peer. Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel - during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to - provide typed MCP request/notification methods. + during an inbound request) extend this. The MCP type layer (`PeerMixin`, + `Connection`, `Context`) builds typed ``send_request`` / convenience methods + on top of this raw channel. """ - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, ) -> dict[str, Any]: - """Send a request and await its result. + """Send a request and await its raw result dict. Raises: MCPError: If the peer responded with an error, or the handler @@ -96,7 +97,7 @@ class DispatchContext(Outbound, Protocol[TransportT_co]): Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. `send_request` raises `NoBackChannelError` if + it. `send_raw_request` raises `NoBackChannelError` if ``transport.can_send_request`` is ``False``. """ diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index e9dd2c843e..b62629b6c8 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -46,7 +46,7 @@ class NoBackChannelError(MCPError): Stateless HTTP and JSON-response-mode HTTP have no channel for the server to push requests (sampling, elicitation, roots/list) to the client. This is - raised by `DispatchContext.send_request` when `transport.can_send_request` + raised by `DispatchContext.send_raw_request` when `transport.can_send_request` is ``False``, and serializes to an ``INVALID_REQUEST`` error response. """ diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 31230fda92..832cead515 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -26,5 +26,5 @@ class TransportContext: ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for stdio, SSE, and stateful streamable HTTP. When ``False``, - `DispatchContext.send_request` raises `NoBackChannelError`. + `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 44ab622ad6..784ef6698f 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -67,16 +67,16 @@ async def running_pair( @pytest.mark.anyio -async def test_send_request_returns_result_from_peer_on_request(): +async def test_send_raw_request_returns_result_from_peer_on_request(): async with running_pair() as (client, _server, _crec, srec): with anyio.fail_after(5): - result = await client.send_request("tools/list", {"cursor": "abc"}) + result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} assert srec.requests == [("tools/list", {"cursor": "abc"})] @pytest.mark.anyio -async def test_send_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -84,13 +84,13 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", {}) + await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "bad cursor" @pytest.mark.anyio -async def test_send_request_wraps_non_mcperror_exception_as_internal_error(): +async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -98,13 +98,13 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", {}) + await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INTERNAL_ERROR assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio -async def test_send_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -113,7 +113,7 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("slow", None, {"timeout": 0}) + await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @@ -127,32 +127,32 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_request_round_trips_to_calling_side(): - """A handler's ctx.send_request reaches the side that made the inbound request.""" +async def test_ctx_send_raw_request_round_trips_to_calling_side(): + """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} @pytest.mark.anyio -async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - return await ctx.send_request("sampling/createMessage", None) + return await ctx.send_raw_request("sampling/createMessage", None) async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: - await client.send_request("tools/call", None) + await client.send_raw_request("tools/call", None) assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @@ -167,7 +167,7 @@ async def server_on_request( async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - await client.send_request("tools/call", None) + await client.send_raw_request("tools/call", None) await crec.notified.wait() assert crec.notifications == [("notifications/message", {"level": "info"})] @@ -187,12 +187,12 @@ async def on_progress(progress: float, total: float | None, message: str | None) async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("tools/call", None, {"on_progress": on_progress}) + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -205,7 +205,7 @@ async def late_start(): tg.start_soon(client.run, c_req, c_notify) tg.start_soon(late_start) with anyio.fail_after(5): - result = await client.send_request("ping", None) + result = await client.send_raw_request("ping", None) assert result == {"echoed": "ping", "params": {}} client.close() server.close() @@ -221,15 +221,15 @@ async def server_on_request( async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) assert result == {"ok": True} @pytest.mark.anyio -async def test_send_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): - await d.send_request("ping", None) + await d.send_raw_request("ping", None) with pytest.raises(RuntimeError, match="no peer"): await d.notify("ping", None) From 3b7fbc7067047405a81bb7267bb21e1eecaa7a0f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:26:22 +0000 Subject: [PATCH 05/27] feat: JSONRPCDispatcher outbound side + parametrized contract tests Chunk (a) of JSONRPCDispatcher: constructor, _Pending/_InFlight/_JSONRPCDispatchContext, send_request/notify and helpers. run() is stubbed. The Dispatcher contract tests are now parametrized over a pair_factory fixture (direct + jsonrpc). The 9 jsonrpc cases are strict-xfail until run()/ _handle_request land in the next commits; once those pass, strict xfail flips to XPASS and forces removal of the marker. Factories return (client, server, close) so running_pair can shut down any implementation uniformly. --- src/mcp/shared/jsonrpc_dispatcher.py | 283 +++++++++++++++++++++++++++ tests/shared/conftest.py | 67 +++++++ tests/shared/test_dispatcher.py | 135 +++++++------ 3 files changed, 421 insertions(+), 64 deletions(-) create mode 100644 src/mcp/shared/jsonrpc_dispatcher.py create mode 100644 tests/shared/conftest.py diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py new file mode 100644 index 0000000000..2a6e0951b8 --- /dev/null +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -0,0 +1,283 @@ +"""JSON-RPC `Dispatcher` implementation. + +Consumes the existing `SessionMessage`-based stream contract that all current +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, +the receive loop, per-request task isolation, cancellation/progress wiring, and +the single exception-to-wire boundary. + +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and +sees only `(ctx, method, params) -> dict`. Transports sit below and see only +`SessionMessage` reads/writes. +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar, overload + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + ServerMessageMetadata, + SessionMessage, +) +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + REQUEST_TIMEOUT, + ErrorData, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ProgressToken, + RequestId, +) + +__all__ = ["JSONRPCDispatcher"] + +logger = logging.getLogger(__name__) + +TransportT = TypeVar("TransportT", bound=TransportContext) + +PeerCancelMode = Literal["interrupt", "signal"] +"""How inbound ``notifications/cancelled`` is applied to a running handler. + +``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets +``ctx.cancel_requested`` and lets the handler observe it cooperatively. +""" + +TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] +"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and +the `SessionMessage.metadata` the transport attached. Defaults to a plain +`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" + + +@dataclass(slots=True) +class _Pending: + """An outbound request awaiting its response.""" + + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] + on_progress: ProgressFnT | None = None + + +@dataclass(slots=True) +class _InFlight(Generic[TransportT]): + """An inbound request currently being handled.""" + + scope: anyio.CancelScope + dctx: _JSONRPCDispatchContext[TransportT] + cancelled_by_peer: bool = False + + +@dataclass +class _JSONRPCDispatchContext(Generic[TransportT]): + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" + + transport: TransportT + _dispatcher: JSONRPCDispatcher[TransportT] + _request_id: RequestId | None + _progress_token: ProgressToken | None = None + _closed: bool = False + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request and not self._closed + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._dispatcher.send_request(method, params, opts, _related_request_id=self._request_id) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._progress_token is None: + return + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) + + def close(self) -> None: + self._closed = True + + +def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: + """Choose the `SessionMessage.metadata` for an outgoing request/notification. + + `ServerMessageMetadata` tags a server-to-client message with the inbound + request it belongs to (so streamable-HTTP can route it onto that request's + SSE stream). `ClientMessageMetadata` carries resumption hints to the + client transport. ``None`` is the common case. + """ + if related_request_id is not None: + return ServerMessageMetadata(related_request_id=related_request_id) + if opts: + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") + if token is not None or on_token is not None: + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) + return None + + +class JSONRPCDispatcher(Generic[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract.""" + + @overload + def __init__( + self: JSONRPCDispatcher[TransportContext], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + ) -> None: ... + @overload + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: ... + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + self._transport_builder = transport_builder or _default_transport_builder + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode + self._raise_handler_exceptions = raise_handler_exceptions + + self._next_id = 0 + self._pending: dict[RequestId, _Pending] = {} + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._running = False + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: RequestId | None = None, + ) -> dict[str, Any]: + """Send a JSON-RPC request and await its response. + + ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a + handler makes a server-to-client request mid-flight; it routes the + outgoing message onto the correct per-request SSE stream (SHTTP) via + `ServerMessageMetadata`. Top-level callers leave it ``None``. + + Raises: + MCPError: The peer responded with a JSON-RPC error; or + ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or + ``CONNECTION_CLOSED`` if the dispatcher shut down while + awaiting the response. + RuntimeError: Called before ``run()`` has started or after it has + finished. + """ + if not self._running: + raise RuntimeError("JSONRPCDispatcher.send_request called before run() / after close") + opts = opts or {} + request_id = self._allocate_id() + out_params = dict(params) if params is not None else None + on_progress = opts.get("on_progress") + if on_progress is not None: + # The caller wants progress updates. The spec mechanism is: include + # `_meta.progressToken` on the request; the peer echoes that token on + # any `notifications/progress` it sends. We use the request id as the + # token so the receive loop can find this `_Pending.on_progress` by + # `_pending[token]` without a second lookup table. + meta = dict((out_params or {}).get("_meta") or {}) + meta["progressToken"] = request_id + out_params = {**(out_params or {}), "_meta": meta} + + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + pending = _Pending(send=send, receive=receive, on_progress=on_progress) + self._pending[request_id] = pending + + metadata = _outbound_metadata(_related_request_id, opts) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + try: + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() + except TimeoutError: + # Spec-recommended courtesy: tell the peer we've given up so it can + # stop work and free resources. v1's BaseSession.send_request does + # NOT do this; it's new behaviour. + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None + except anyio.get_cancelled_exc_class(): + # Our caller's scope was cancelled. We're already inside a cancelled + # scope, so any bare `await` here re-raises immediately — shield to + # let the courtesy cancel notification go out before we propagate. + with anyio.CancelScope(shield=True): + await self._cancel_outbound(request_id, "caller cancelled") + raise + finally: + # Always remove the waiter, even on cancel/timeout, so a late + # response from the peer (race) hits a closed stream and is dropped + # in `_dispatch` rather than leaking. + self._pending.pop(request_id, None) + send.close() + receive.close() + + if isinstance(outcome, ErrorData): + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) + return outcome + + async def notify( + self, + method: str, + params: Mapping[str, Any] | None, + *, + _related_request_id: RequestId | None = None, + ) -> None: + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) + await self._write(msg, _outbound_metadata(_related_request_id, None)) + + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + raise NotImplementedError # chunk (b) + + def _allocate_id(self) -> int: + self._next_id += 1 + return self._next_id + + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: + try: + await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) + except anyio.BrokenResourceError: + pass + except anyio.ClosedResourceError: + pass diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py new file mode 100644 index 0000000000..ffa254804e --- /dev/null +++ b/tests/shared/conftest.py @@ -0,0 +1,67 @@ +"""Shared fixtures for `Dispatcher` contract tests. + +The `pair_factory` fixture parametrizes contract tests over every `Dispatcher` +implementation, so the same behavioral assertions run against `DirectDispatcher` +(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams). +""" + +from collections.abc import Callable + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import Dispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext + +DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] +PairFactory = Callable[..., DispatcherTriple] + + +def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple: + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + + def close() -> None: + client.close() + server.close() + + return client, server, close + + +def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: + """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=can_send_request) + + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder) + + def close() -> None: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + return client, server, close + + +_JSONRPC_XFAIL = pytest.mark.xfail( + strict=True, + reason="JSONRPCDispatcher.run() not yet implemented (PR2 chunks b/c)", +) + + +@pytest.fixture( + params=[ + pytest.param(direct_pair, id="direct"), + pytest.param(jsonrpc_pair, id="jsonrpc", marks=_JSONRPC_XFAIL), + ] +) +def pair_factory(request: pytest.FixtureRequest) -> PairFactory: + return request.param + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 784ef6698f..31fba3dd5d 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -1,8 +1,9 @@ -"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. +"""Behavioral tests for the Dispatcher Protocol. -These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using -the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in -``test_jsonrpc_dispatcher.py``. +The contract tests are parametrized over every `Dispatcher` implementation via +the `pair_factory` fixture (see ``conftest.py``); they must pass for both +`DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass +a concrete factory directly. """ from collections.abc import AsyncIterator, Mapping @@ -14,10 +15,12 @@ from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound -from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT +from .conftest import PairFactory, direct_pair + class Recorder: def __init__(self) -> None: @@ -44,31 +47,34 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: @asynccontextmanager async def running_pair( + factory: PairFactory, *, server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, -) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: +) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" - client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client, server, close = factory(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() c_req, c_notify = echo_handlers(client_rec) s_req, s_notify = echo_handlers(server_rec) - async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) - try: - yield client, server, client_rec, server_rec - finally: - client.close() - server.close() + try: + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + tg.cancel_scope.cancel() + finally: + close() @pytest.mark.anyio -async def test_send_raw_request_returns_result_from_peer_on_request(): - async with running_pair() as (client, _server, _crec, srec): +async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} @@ -76,13 +82,13 @@ async def test_send_raw_request_returns_result_from_peer_on_request(): @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS @@ -90,36 +96,22 @@ async def on_request( @pytest.mark.anyio -async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): - async def on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - raise ValueError("oops") - - async with running_pair(server_on_request=on_request) as (client, *_): - with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_raw_request("tools/list", {}) - assert exc.value.error.code == INTERNAL_ERROR - assert isinstance(exc.value.__cause__, ValueError) - - -@pytest.mark.anyio -async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @pytest.mark.anyio -async def test_notify_invokes_peer_on_notify(): - async with running_pair() as (client, _server, _crec, srec): +async def test_notify_invokes_peer_on_notify(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.notify("notifications/initialized", {"v": 1}) await srec.notified.wait() @@ -127,7 +119,7 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_raw_request_round_trips_to_calling_side(): +async def test_ctx_send_raw_request_round_trips_to_calling_side(pair_factory: PairFactory): """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( @@ -136,7 +128,7 @@ async def server_on_request( sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] @@ -144,28 +136,27 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_raw_request("sampling/createMessage", None) - async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): - with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + async with running_pair(pair_factory, server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/call", None) - assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio -async def test_ctx_notify_invokes_calling_side_on_notify(): +async def test_ctx_notify_invokes_calling_side_on_notify(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) await crec.notified.wait() @@ -173,7 +164,7 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_progress_invokes_caller_on_progress_callback(): +async def test_ctx_progress_invokes_caller_on_progress_callback(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -185,14 +176,44 @@ async def server_on_request( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_request=server_on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): + """DirectDispatcher-specific: the original exception is chained via __cause__.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -212,21 +233,7 @@ async def late_start(): @pytest.mark.anyio -async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - await ctx.progress(0.5) - return {"ok": True} - - async with running_pair(server_on_request=server_on_request) as (client, *_): - with anyio.fail_after(5): - result = await client.send_raw_request("tools/call", None) - assert result == {"ok": True} - - -@pytest.mark.anyio -async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_direct_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): await d.send_raw_request("ping", None) @@ -235,7 +242,7 @@ async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_conne @pytest.mark.anyio -async def test_close_makes_run_return(): +async def test_direct_close_makes_run_return(): client, server = create_direct_dispatcher_pair() on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): From ca77ffe31db198fbbacbc4a5b8e920879b51e710 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:59:11 +0000 Subject: [PATCH 06/27] feat: JSONRPCDispatcher receive loop and dispatch (chunk b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit run() drives the receive loop in a per-request task group; task_status.started() fires once send_request is usable. _dispatch routes each inbound message synchronously (no awaits — send_nowait/_spawn only) to avoid head-of-line blocking. _spawn propagates the sender's contextvars via Context.run(tg.start_soon, ...) so auth/OTel set by ASGI middleware survive. _fan_out_closed wakes pending send_request waiters with CONNECTION_CLOSED on shutdown (called both post-EOF and in finally; idempotent). Wire-param extraction (progressToken, cancelled.requestId, progress fields) uses structural match patterns — runtime narrowing, no casts, no mcp.types model coupling; malformed input fails to match and the correlation is skipped. _handle_request is happy-path only here (run on_request, write response); the exception-to-wire boundary lands in the next commit. Dispatcher.run() Protocol gained a task_status kwarg (it's a contract-level guarantee). DirectDispatcher.run() updated to match. running_pair now uses tg.start so the test body runs only once the dispatcher is ready. 20 contract tests pass; the 2 needing the exception boundary are strict-xfail. --- src/mcp/shared/direct_dispatcher.py | 10 +- src/mcp/shared/dispatcher.py | 13 +- src/mcp/shared/jsonrpc_dispatcher.py | 234 ++++++++++++++++++++++++++- tests/shared/conftest.py | 22 ++- tests/shared/test_dispatcher.py | 18 ++- 5 files changed, 274 insertions(+), 23 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index bb5639a136..27443ec874 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -20,6 +20,7 @@ from typing import Any import anyio +import anyio.abc from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError @@ -101,10 +102,17 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: self._on_request = on_request self._on_notify = on_notify self._ready.set() + task_status.started() await self._closed.wait() def close(self) -> None: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index ee02e23896..20c090323b 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -20,6 +20,7 @@ from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable import anyio +import anyio.abc from mcp.shared.transport_context import TransportContext @@ -136,11 +137,21 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: """Drive the receive loop until the underlying channel closes. Each inbound request is dispatched to ``on_request`` in its own task; the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. + + ``task_status.started()`` is called once the dispatcher is ready to + accept ``send_request``/``notify`` calls, so callers can use + ``await tg.start(dispatcher.run, on_request, on_notify)``. """ ... diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 2a6e0951b8..6bf957c19a 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -8,20 +8,30 @@ The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and sees only `(ctx, method, params) -> dict`. Transports sit below and see only `SessionMessage` reads/writes. + +The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and +dicts — but it intercepts ``notifications/cancelled`` and +``notifications/progress`` because request correlation, cancellation and +progress are exactly the wiring this layer exists to provide. Those few wire +shapes are extracted with structural ``match`` patterns (no casts, no +``mcp.types`` model coupling); a malformed payload simply fails to match and +the correlation is skipped. """ from __future__ import annotations +import contextvars import logging -from collections.abc import Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeVar, overload +from typing import Any, Generic, Literal, TypeVar, cast, overload import anyio +import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import ( ClientMessageMetadata, @@ -31,11 +41,14 @@ ) from mcp.shared.transport_context import TransportContext from mcp.types import ( + CONNECTION_CLOSED, REQUEST_TIMEOUT, ErrorData, + JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, ProgressToken, RequestId, ) @@ -141,8 +154,12 @@ def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | return None -class JSONRPCDispatcher(Generic[TransportT]): - """`Dispatcher` over the existing `SessionMessage` stream contract.""" +class JSONRPCDispatcher(Dispatcher[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract. + + Inherits the `Dispatcher` Protocol explicitly so pyright checks + conformance at the class definition rather than at first use. + """ @overload def __init__( @@ -171,13 +188,20 @@ def __init__( ) -> None: self._read_stream = read_stream self._write_stream = write_stream - self._transport_builder = transport_builder or _default_transport_builder + # The overloads guarantee that when `transport_builder` is omitted, + # `TransportT` is `TransportContext`, so the default is type-correct; + # pyright can't see across overloads, hence the cast. + self._transport_builder = cast( + "Callable[[RequestId | None, MessageMetadata], TransportT]", + transport_builder or _default_transport_builder, + ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._tg: anyio.abc.TaskGroup | None = None self._running = False async def send_request( @@ -219,6 +243,11 @@ async def send_request( meta["progressToken"] = request_id out_params = {**(out_params or {}), "_meta": meta} + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an + # outcome and dropping the late/redundant signal is correct. buffer=0 + # is unsafe — there's a window between registering `_pending[id]` and + # parking in `receive()` where a close signal would be lost. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) self._pending[request_id] = pending @@ -264,8 +293,197 @@ async def notify( msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) await self._write(msg, _outbound_metadata(_related_request_id, None)) - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: - raise NotImplementedError # chunk (b) + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the read stream closes. + + Each inbound request is handled in its own task in an internal task + group; ``task_status.started()`` fires once that group is open, so + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_request`` + is usable. + """ + try: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + async with self._read_stream: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + self._dispatch(item, on_request, on_notify, sender_ctx) + # Read stream EOF: wake any blocked `send_request` waiters now, + # *before* the task group joins, so handlers parked in + # `dctx.send_request()` can unwind and the join doesn't deadlock. + self._running = False + self._fan_out_closed() + finally: + # Covers the cancel/crash paths where the inline fan-out above is + # never reached. Idempotent. + self._running = False + self._tg = None + self._fan_out_closed() + + def _dispatch( + self, + item: SessionMessage | Exception, + on_request: OnRequest, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + """Route one inbound item. Synchronous: never awaits. + + Everything here is `send_nowait` or `_spawn`. An `await` would let one + slow message head-of-line block the entire read loop. + """ + if isinstance(item, Exception): + logger.debug("transport yielded exception: %r", item) + return + metadata = item.metadata + msg = item.message + match msg: + case JSONRPCRequest(): + self._dispatch_request(msg, metadata, on_request, sender_ctx) + case JSONRPCNotification(): + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) + case JSONRPCResponse(): + self._resolve_pending(msg.id, msg.result) + case JSONRPCError(): + # `id` may be None per JSON-RPC (parse error before id known). + self._resolve_pending(msg.id, msg.error) + + def _dispatch_request( + self, + req: JSONRPCRequest, + metadata: MessageMetadata, + on_request: OnRequest, + sender_ctx: contextvars.Context | None, + ) -> None: + progress_token: ProgressToken | None + match req.params: + case {"_meta": {"progressToken": str() | int() as progress_token}}: + pass + case _: + progress_token = None + transport_ctx = self._transport_builder(req.id, metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, + _dispatcher=self, + _request_id=req.id, + _progress_token=progress_token, + ) + scope = anyio.CancelScope() + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + + def _dispatch_notification( + self, + msg: JSONRPCNotification, + metadata: MessageMetadata, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + if msg.method == "notifications/cancelled": + match msg.params: + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: + in_flight.cancelled_by_peer = True + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() + case _: + pass + return + if msg.method == "notifications/progress": + match msg.params: + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( + pending := self._pending.get(token) + ) is not None and pending.on_progress is not None: + total = msg.params.get("total") + message = msg.params.get("message") + self._spawn( + pending.on_progress, + float(progress), + float(total) if isinstance(total, int | float) else None, + message if isinstance(message, str) else None, + sender_ctx=sender_ctx, + ) + case _: + pass + # fall through: progress is also teed to on_notify + transport_ctx = self._transport_builder(None, metadata) + dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: + pending = self._pending.get(request_id) if request_id is not None else None + if pending is None: + logger.debug("dropping response for unknown/late request id %r", request_id) + return + try: + pending.send.send_nowait(outcome) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("waiter for request id %r already gone", request_id) + + def _spawn( + self, + fn: Callable[..., Awaitable[Any]], + *args: object, + sender_ctx: contextvars.Context | None, + ) -> None: + """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. + + ASGI middleware (auth, OTel) sets contextvars on the request task that + wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes + the spawned handler inherit *that* context instead of the receive + loop's, so ``auth_context_var`` and OTel spans survive. + """ + assert self._tg is not None + if sender_ctx is not None: + sender_ctx.run(self._tg.start_soon, fn, *args) + else: + self._tg.start_soon(fn, *args) + + def _fan_out_closed(self) -> None: + """Wake every pending ``send_request`` waiter with ``CONNECTION_CLOSED``. + + Synchronous (uses ``send_nowait``) because it's called from ``finally`` + which may be inside a cancelled scope. Idempotent. + """ + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + for pending in self._pending.values(): + try: + pending.send.send_nowait(closed) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + pass + self._pending.clear() + + async def _handle_request( + self, + req: JSONRPCRequest, + dctx: _JSONRPCDispatchContext[TransportT], + scope: anyio.CancelScope, + on_request: OnRequest, + ) -> None: + """Run ``on_request`` for one inbound request and write its response. + + Chunk (b): happy-path only. The full exception-to-wire boundary + (MCPError, ValidationError, INTERNAL_ERROR scrubbing, peer-cancel + no-response) lands in chunk (c). + """ + try: + with scope: + result = await on_request(dctx, req.method, req.params) + await self._write(JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result)) + finally: + self._in_flight.pop(req.id, None) + dctx.close() def _allocate_id(self) -> int: self._next_id += 1 diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py index ffa254804e..b7049493a2 100644 --- a/tests/shared/conftest.py +++ b/tests/shared/conftest.py @@ -48,20 +48,26 @@ def close() -> None: return client, server, close -_JSONRPC_XFAIL = pytest.mark.xfail( - strict=True, - reason="JSONRPCDispatcher.run() not yet implemented (PR2 chunks b/c)", -) - - @pytest.fixture( params=[ pytest.param(direct_pair, id="direct"), - pytest.param(jsonrpc_pair, id="jsonrpc", marks=_JSONRPC_XFAIL), + pytest.param(jsonrpc_pair, id="jsonrpc"), ] ) def pair_factory(request: pytest.FixtureRequest) -> PairFactory: return request.param -__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] +def xfail_jsonrpc_chunk_c(request: pytest.FixtureRequest, factory: PairFactory) -> None: + """Apply a strict xfail when running against the JSON-RPC dispatcher. + + Use for contract tests that require `_handle_request`'s exception boundary + (PR2 chunk c). Remove once that lands. + """ + if factory is jsonrpc_pair: + request.applymarker( + pytest.mark.xfail(strict=True, reason="needs JSONRPCDispatcher._handle_request exception boundary") + ) + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair", "xfail_jsonrpc_chunk_c"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 31fba3dd5d..aef6b60bcb 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -19,7 +19,7 @@ from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT -from .conftest import PairFactory, direct_pair +from .conftest import PairFactory, direct_pair, xfail_jsonrpc_chunk_c class Recorder: @@ -62,8 +62,8 @@ async def running_pair( s_req, s_notify = echo_handlers(server_rec) try: async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) + await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) try: yield client, server, client_rec, server_rec finally: @@ -82,7 +82,11 @@ async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged( + pair_factory: PairFactory, request: pytest.FixtureRequest +): + xfail_jsonrpc_chunk_c(request, pair_factory) + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -136,7 +140,11 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows( + pair_factory: PairFactory, request: pytest.FixtureRequest +): + xfail_jsonrpc_chunk_c(request, pair_factory) + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: From bb6f091d13b6657b09ef1bc60a57f1b7b3357267 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:20:43 +0000 Subject: [PATCH 07/27] feat: JSONRPCDispatcher exception boundary (chunk c) _handle_request is now the single exception-to-wire boundary: - MCPError -> JSONRPCError(e.error) - pydantic ValidationError -> INVALID_PARAMS - Exception -> INTERNAL_ERROR(str(e)), logged, optionally re-raised - outer-cancel (run() TG shutdown) -> shielded REQUEST_CANCELLED write, re-raise - peer-cancel (notifications/cancelled) -> scope swallows, no response written dctx.close() runs in an inner finally so the back-channel shuts the moment the handler exits. _write_result/_write_error swallow Broken/ClosedResourceError so a dropped connection during the response write doesn't crash the dispatcher. All 22 contract tests now pass against both DirectDispatcher and JSONRPCDispatcher; chunk-c xfail markers removed. --- src/mcp/shared/jsonrpc_dispatcher.py | 54 ++++++++++++++++++++++++---- tests/shared/conftest.py | 14 +------- tests/shared/test_dispatcher.py | 14 ++------ 3 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 6bf957c19a..f35b37cf9d 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -29,6 +29,7 @@ import anyio import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT @@ -42,6 +43,9 @@ from mcp.shared.transport_context import TransportContext from mcp.types import ( CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + REQUEST_CANCELLED, REQUEST_TIMEOUT, ErrorData, JSONRPCError, @@ -473,17 +477,43 @@ async def _handle_request( ) -> None: """Run ``on_request`` for one inbound request and write its response. - Chunk (b): happy-path only. The full exception-to-wire boundary - (MCPError, ValidationError, INTERNAL_ERROR scrubbing, peer-cancel - no-response) lands in chunk (c). + This is the single exception-to-wire boundary: handler exceptions are + caught here and serialized to ``JSONRPCError``. Nothing above this in + the stack constructs wire errors. """ try: with scope: - result = await on_request(dctx, req.method, req.params) - await self._write(JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result)) + try: + result = await on_request(dctx, req.method, req.params) + finally: + # Close the back-channel the moment the handler exits + # (success or raise), before the response write — a handler + # spawning detached work that later calls + # `dctx.send_request()` should see `NoBackChannelError`. + dctx.close() + await self._write_result(req.id, result) + # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio + # swallows a scope's *own* cancel at __exit__, so the result write + # (or the handler) is interrupted and execution lands here without + # reaching the `except cancelled` arm below. Spec SHOULD: send no + # response — fall through to `finally`. + except anyio.get_cancelled_exc_class(): + # Outer-cancel: run()'s task group is shutting down. Any bare + # `await` here re-raises immediately, so shield the courtesy write. + with anyio.CancelScope(shield=True): + await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + raise + except MCPError as e: + await self._write_error(req.id, e.error) + except ValidationError as e: + await self._write_error(req.id, ErrorData(code=INVALID_PARAMS, message=str(e))) + except Exception as e: + logger.exception("handler for %r raised", req.method) + await self._write_error(req.id, ErrorData(code=INTERNAL_ERROR, message=str(e))) + if self._raise_handler_exceptions: + raise finally: self._in_flight.pop(req.id, None) - dctx.close() def _allocate_id(self) -> int: self._next_id += 1 @@ -492,6 +522,18 @@ def _allocate_id(self) -> int: async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + try: + await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped result for %r: write stream closed", request_id) + + async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + try: + await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped error for %r: write stream closed", request_id) + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: try: await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py index b7049493a2..1222c05aba 100644 --- a/tests/shared/conftest.py +++ b/tests/shared/conftest.py @@ -58,16 +58,4 @@ def pair_factory(request: pytest.FixtureRequest) -> PairFactory: return request.param -def xfail_jsonrpc_chunk_c(request: pytest.FixtureRequest, factory: PairFactory) -> None: - """Apply a strict xfail when running against the JSON-RPC dispatcher. - - Use for contract tests that require `_handle_request`'s exception boundary - (PR2 chunk c). Remove once that lands. - """ - if factory is jsonrpc_pair: - request.applymarker( - pytest.mark.xfail(strict=True, reason="needs JSONRPCDispatcher._handle_request exception boundary") - ) - - -__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair", "xfail_jsonrpc_chunk_c"] +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index aef6b60bcb..fc967c1299 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -19,7 +19,7 @@ from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT -from .conftest import PairFactory, direct_pair, xfail_jsonrpc_chunk_c +from .conftest import PairFactory, direct_pair class Recorder: @@ -82,11 +82,7 @@ async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged( - pair_factory: PairFactory, request: pytest.FixtureRequest -): - xfail_jsonrpc_chunk_c(request, pair_factory) - +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -140,11 +136,7 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows( - pair_factory: PairFactory, request: pytest.FixtureRequest -): - xfail_jsonrpc_chunk_c(request, pair_factory) - +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: From 11e83ffa8ea9b45acc23c7f866bf3c4784d0d2da Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:46:17 +0000 Subject: [PATCH 08/27] test: JSON-RPC-specific dispatcher tests + coverage to 100% Covers behaviors with no DirectDispatcher analog: out-of-order response correlation, INTERNAL_ERROR over the wire, peer-cancel in interrupt and signal modes, CONNECTION_CLOSED on stream EOF mid-await, late-response drop, raise_handler_exceptions propagation, ServerMessageMetadata tagging on ctx.send_request, null-id JSONRPCError drop, ValidationError->INVALID_PARAMS, contextvar propagation via _spawn, and the defensive Broken/Closed/WouldBlock catches. Two small src tweaks for coverage: - _cancel_outbound: combine the two except arms into one tuple - _dispatch: pragma no-branch on the final case (match is exhaustive over JSONRPCMessage; the no-match arc is unreachable) 43 tests, 100% coverage on all PR2 modules, 0.15s wall-clock. --- src/mcp/shared/jsonrpc_dispatcher.py | 8 +- tests/shared/test_jsonrpc_dispatcher.py | 531 ++++++++++++++++++++++++ 2 files changed, 535 insertions(+), 4 deletions(-) create mode 100644 tests/shared/test_jsonrpc_dispatcher.py diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index f35b37cf9d..bbf5666069 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -359,8 +359,10 @@ def _dispatch( self._dispatch_notification(msg, metadata, on_notify, sender_ctx) case JSONRPCResponse(): self._resolve_pending(msg.id, msg.result) - case JSONRPCError(): + case JSONRPCError(): # pragma: no branch # `id` may be None per JSON-RPC (parse error before id known). + # The match is exhaustive over JSONRPCMessage; the no-match arc + # on this final case is unreachable. self._resolve_pending(msg.id, msg.error) def _dispatch_request( @@ -537,7 +539,5 @@ async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: try: await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) - except anyio.BrokenResourceError: - pass - except anyio.ClosedResourceError: + except (anyio.BrokenResourceError, anyio.ClosedResourceError): pass diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py new file mode 100644 index 0000000000..ff24ef4c6b --- /dev/null +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -0,0 +1,531 @@ +"""JSON-RPC-specific Dispatcher tests. + +Behaviors with no `DirectDispatcher` analog: request-id correlation, the +exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. +The contract tests shared with `DirectDispatcher` live in +``test_dispatcher.py``. +""" + +import contextvars +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] + JSONRPCDispatcher, + _outbound_metadata, + _Pending, +) +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + ErrorData, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Tool, +) + +from .conftest import jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_concurrent_send_requests_correlate_by_id_when_responses_arrive_out_of_order(): + release_first = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await release_first.wait() + return {"m": method} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + results: dict[str, dict[str, Any]] = {} + + async def call(method: str) -> None: + results[method] = await client.send_request(method, None) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(call, "first") + await anyio.sleep(0) + tg.start_soon(call, "second") + await anyio.sleep(0) + # second resolves while first is still parked + assert "first" not in results + release_first.set() + assert results == {"first": {"m": "first"}, "second": {"m": "second"}} + + +@pytest.mark.anyio +async def test_handler_raising_exception_sends_internal_error_with_str_message(): + """Per design: INTERNAL_ERROR carries str(e), not a scrubbed message.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("kaboom") + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "kaboom" + assert exc.value.__cause__ is None # cause does not survive the wire + + +@pytest.mark.anyio +async def test_peer_cancel_interrupt_mode_sets_cancel_requested_and_sends_no_response(): + handler_started = anyio.Event() + handler_exited = anyio.Event() + seen_ctx: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen_ctx.append(ctx) + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call_then_record() -> None: + with pytest.raises(MCPError): # we'll cancel via tg below + await client.send_request("slow", None) + + tg.start_soon(call_then_record) + await handler_started.wait() + # cancel just the handler (peer-cancel), not our caller + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + # Handler torn down, no response was written; caller is still parked. + # Cancel the caller's task to end the test. + tg.cancel_scope.cancel() + assert seen_ctx[0].cancel_requested.is_set() + + +@pytest.mark.anyio +async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): + handler_started = anyio.Event() + cancel_seen = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + cancel_seen.set() + return {"finished": True} + + def factory(*, can_send_request: bool = True): + client, server, close = jsonrpc_pair(can_send_request=can_send_request) + # Reach in to set signal mode on the server side. + assert isinstance(server, JSONRPCDispatcher) + server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] + return client, server, close + + result_box: list[dict[str, Any]] = [] + async with running_pair(factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call() -> None: + result_box.append(await client.send_request("slow", None)) + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await cancel_seen.wait() + assert result_box == [{"finished": True}] + + +@pytest.mark.anyio +async def test_send_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_request is woken with CONNECTION_CLOSED when run() exits.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + + tg.start_soon(caller) + await anyio.sleep(0) + # No server: simulate the peer dropping by closing the read side. + s2c_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_late_response_after_timeout_is_dropped_without_crashing(): + handler_started = anyio.Event() + proceed = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await proceed.wait() + return {"late": True} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await client.send_request("slow", None, {"timeout": 0}) + # The server handler is still running; let it finish and write a + # response for an id the client has already discarded. + await handler_started.wait() + proceed.set() + # One more round-trip proves the dispatcher is still healthy. + assert await client.send_request("ping", None) == {"late": True} + + +@pytest.mark.anyio +async def test_raise_handler_exceptions_true_propagates_out_of_run(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, transport_builder=builder, raise_handler_exceptions=True + ) + + async def boom(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("propagate me") + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + with pytest.raises(BaseException) as exc: + async with anyio.create_task_group() as tg: + await tg.start(server.run, boom, on_notify) + # Inject a request directly onto the server's read stream. + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) + ) + assert exc.group_contains(RuntimeError, match="propagate me") + # The error response was still written before re-raising. + sent = s2c_recv.receive_nowait() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCError) + assert sent.message.error.code == INTERNAL_ERROR + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_send_request_tags_outbound_with_server_message_metadata(): + """Server-to-client requests carry related_request_id for SHTTP routing.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + # Kick the server with an inbound request id=7. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ServerMessageMetadata) + assert outbound.metadata.related_request_id == 7 + # Reply so the handler completes cleanly. + await c2s_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) + ) + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.id == 7 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.25) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None, {"on_progress": on_progress}) + assert received == [(0.25, None, None)] + + +@pytest.mark.anyio +async def test_handler_raising_validation_error_sends_invalid_params(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_request("t", None) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_send_request_before_run_raises_runtimeerror(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + try: + with pytest.raises(RuntimeError, match="before run"): + await d.send_request("ping", None) + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_transport_exception_in_read_stream_is_logged_and_dropped(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + # Dispatcher must remain healthy after the dropped exception. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/progress", {"progressToken": 999, "progress": 0.5}) + await srec.notified.wait() + assert srec.notifications == [("notifications/progress", {"progressToken": 999, "progress": 0.5})] + + +@pytest.mark.anyio +async def test_cancelled_notification_for_unknown_request_id_is_noop(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/cancelled", {"requestId": 999}) + # No effect; dispatcher remains healthy. + assert await client.send_request("t", None) == {"echoed": "t", "params": {}} + assert srec.notifications == [] # cancelled is fully consumed, never teed + + +_probe: contextvars.ContextVar[str] = contextvars.ContextVar("probe", default="unset") + + +@pytest.mark.anyio +async def test_handler_inherits_sender_contextvars_via_spawn(): + """The handler task sees contextvars set by the task that wrote into the read stream.""" + raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) + read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) + write_send = ContextSendStream[SessionMessage | Exception](raw_send) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send) + + seen: list[str] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(_probe.get()) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + + async def sender() -> None: + _probe.set("from-sender") + await write_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + + tg.start_soon(sender) + with anyio.fail_after(5): + resp = await out_recv.receive() + assert isinstance(resp, SessionMessage) + tg.cancel_scope.cancel() + finally: + for s in (raw_send, raw_recv, out_send, out_recv): + s.close() + assert seen == ["from-sender"] + + +@pytest.mark.anyio +async def test_response_write_after_peer_drop_is_swallowed(): + """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + proceed = anyio.Event() + handlers_done = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await proceed.wait() + if method == "raise": + handlers_done.set() + raise MCPError(code=INTERNAL_ERROR, message="x") + return {"ok": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ok", params=None))) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="raise", params=None)) + ) + await anyio.sleep(0) + # Peer drops: close the receive end so the server's writes hit BrokenResourceError. + s2c_recv.close() + proceed.set() + with anyio.fail_after(5): + await handlers_done.wait() + # run() must still be healthy — close the read side to let it exit cleanly. + c2s_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): + """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + caller_done = anyio.Event() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + caller_scope = anyio.CancelScope() + + async def caller() -> None: + with caller_scope: + await client.send_request("slow", None) + caller_done.set() + + tg.start_soon(caller) + # Deterministic proof the request write completed: pull it off the wire. + with anyio.fail_after(5): + sent = await c2s_recv.receive() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCRequest) + # Now safe: close the client's write end, then cancel the caller. The + # shielded `_cancel_outbound` write hits ClosedResourceError and is + # swallowed; cancellation propagates cleanly. + c2s_send.close() + caller_scope.cancel() + with anyio.fail_after(5): + await caller_done.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): + """White-box: a response for an id still in _pending but whose waiter has gone.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + recv.close() # waiter gone — send_nowait will raise BrokenResourceError + d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): + s.close() + + +def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): + """White-box: the buffer=1 invariant — WouldBlock means waiter already has an outcome.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + # Register a fake pending and pre-fill its single buffer slot. + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + send.send_nowait({"real": "result"}) + d._fan_out_closed() # pyright: ignore[reportPrivateUsage] + # The real result is still there; the close signal was dropped. + assert recv.receive_nowait() == {"real": "result"} + assert d._pending == {} # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send, recv): + s.close() + + +def test_outbound_metadata_with_resumption_token_returns_client_metadata(): + md = _outbound_metadata(None, {"resumption_token": "abc"}) + assert isinstance(md, ClientMessageMetadata) + assert md.resumption_token == "abc" + assert _outbound_metadata(None, None) is None + assert _outbound_metadata(None, {}) is None + + +@pytest.mark.anyio +async def test_jsonrpc_error_response_with_null_id_is_dropped(): + """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + await s2c_send.send( + SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) + ) + await anyio.sleep(0) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() From 64f23487347bfda1d93e52155005616be518d360 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:18:48 +0000 Subject: [PATCH 09/27] ci: run full matrix on PRs targeting any branch The pull_request branch filter meant the test/lint/coverage matrix only ran on PRs targeting main or v1.x. Stacked PRs (targeting feature branches) only got the conformance checks, which are continue-on-error and don't exercise unit tests. Removing the filter so the full matrix runs on every PR. --- .github/workflows/main.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d34e438fc9..341df0abb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,6 @@ on: branches: ["main", "v1.x"] tags: ["v*.*.*"] pull_request: - branches: ["main", "v1.x"] permissions: contents: read From af09bf0d323c376b9824f916a16dfbce5fb552e4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 17:29:41 +0000 Subject: [PATCH 10/27] test: address 3.11/3.14 coverage instrumentation quirks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3.14: nested async-with arc misreporting on three create_task_group lines (the documented AGENTS.md case) — pragma: no branch. 3.11: lines after async-CM exit with pytest.raises mis-traced in one test — moved the asserts inside the context manager. --- tests/shared/test_dispatcher.py | 4 ++-- tests/shared/test_jsonrpc_dispatcher.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index fc967c1299..bdadd4cdae 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -208,8 +208,8 @@ async def on_request( async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", {}) - assert exc.value.error.code == INTERNAL_ERROR - assert isinstance(exc.value.__cause__, ValueError) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index ff24ef4c6b..be6386d090 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -56,7 +56,7 @@ async def call(method: str) -> None: results[method] = await client.send_request(method, None) with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch tg.start_soon(call, "first") await anyio.sleep(0) tg.start_soon(call, "second") @@ -99,7 +99,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch async def call_then_record() -> None: with pytest.raises(MCPError): # we'll cancel via tg below @@ -137,7 +137,7 @@ def factory(*, can_send_request: bool = True): result_box: list[dict[str, Any]] = [] async with running_pair(factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch async def call() -> None: result_box.append(await client.send_request("slow", None)) From 0ec78d0197194e5b8ce1446062e44d56fdf95e07 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 21:49:25 +0000 Subject: [PATCH 11/27] refactor: rename send_request to send_raw_request in JSONRPCDispatcher Follows the Outbound Protocol rename in the previous commit. Mechanical rename across JSONRPCDispatcher, _JSONRPCDispatchContext, and tests. --- src/mcp/shared/jsonrpc_dispatcher.py | 18 ++++++------- tests/shared/test_jsonrpc_dispatcher.py | 36 ++++++++++++------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index bbf5666069..f1e7b3675e 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -112,7 +112,7 @@ def can_send_request(self) -> bool: async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._dispatcher.notify(method, params, _related_request_id=self._request_id) - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -120,7 +120,7 @@ async def send_request( ) -> dict[str, Any]: if not self.can_send_request: raise NoBackChannelError(method) - return await self._dispatcher.send_request(method, params, opts, _related_request_id=self._request_id) + return await self._dispatcher.send_raw_request(method, params, opts, _related_request_id=self._request_id) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: if self._progress_token is None: @@ -208,7 +208,7 @@ def __init__( self._tg: anyio.abc.TaskGroup | None = None self._running = False - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -232,7 +232,7 @@ async def send_request( finished. """ if not self._running: - raise RuntimeError("JSONRPCDispatcher.send_request called before run() / after close") + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") opts = opts or {} request_id = self._allocate_id() out_params = dict(params) if params is not None else None @@ -308,7 +308,7 @@ async def run( Each inbound request is handled in its own task in an internal task group; ``task_status.started()`` fires once that group is open, so - ``await tg.start(dispatcher.run, ...)`` resumes when ``send_request`` + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_raw_request`` is usable. """ try: @@ -323,9 +323,9 @@ async def run( # snapshot per message). Plain memory streams don't. sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) self._dispatch(item, on_request, on_notify, sender_ctx) - # Read stream EOF: wake any blocked `send_request` waiters now, + # Read stream EOF: wake any blocked `send_raw_request` waiters now, # *before* the task group joins, so handlers parked in - # `dctx.send_request()` can unwind and the join doesn't deadlock. + # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. self._running = False self._fan_out_closed() finally: @@ -457,7 +457,7 @@ def _spawn( self._tg.start_soon(fn, *args) def _fan_out_closed(self) -> None: - """Wake every pending ``send_request`` waiter with ``CONNECTION_CLOSED``. + """Wake every pending ``send_raw_request`` waiter with ``CONNECTION_CLOSED``. Synchronous (uses ``send_nowait``) because it's called from ``finally`` which may be inside a cancelled scope. Idempotent. @@ -491,7 +491,7 @@ async def _handle_request( # Close the back-channel the moment the handler exits # (success or raise), before the response write — a handler # spawning detached work that later calls - # `dctx.send_request()` should see `NoBackChannelError`. + # `dctx.send_raw_request()` should see `NoBackChannelError`. dctx.close() await self._write_result(req.id, result) # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index be6386d090..7f9f11718b 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -41,7 +41,7 @@ @pytest.mark.anyio -async def test_concurrent_send_requests_correlate_by_id_when_responses_arrive_out_of_order(): +async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): release_first = anyio.Event() async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -53,7 +53,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | results: dict[str, dict[str, Any]] = {} async def call(method: str) -> None: - results[method] = await client.send_request(method, None) + results[method] = await client.send_raw_request(method, None) with anyio.fail_after(5): async with anyio.create_task_group() as tg: # pragma: no branch @@ -76,7 +76,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", None) + await client.send_raw_request("tools/list", None) assert exc.value.error.code == INTERNAL_ERROR assert exc.value.error.message == "kaboom" assert exc.value.__cause__ is None # cause does not survive the wire @@ -103,7 +103,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async def call_then_record() -> None: with pytest.raises(MCPError): # we'll cancel via tg below - await client.send_request("slow", None) + await client.send_raw_request("slow", None) tg.start_soon(call_then_record) await handler_started.wait() @@ -140,7 +140,7 @@ def factory(*, can_send_request: bool = True): async with anyio.create_task_group() as tg: # pragma: no branch async def call() -> None: - result_box.append(await client.send_request("slow", None)) + result_box.append(await client.send_raw_request("slow", None)) tg.start_soon(call) await handler_started.wait() @@ -150,8 +150,8 @@ async def call() -> None: @pytest.mark.anyio -async def test_send_request_raises_connection_closed_when_read_stream_eofs_mid_await(): - """A blocked send_request is woken with CONNECTION_CLOSED when run() exits.""" +async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -162,7 +162,7 @@ async def test_send_request_raises_connection_closed_when_read_stream_eofs_mid_a async def caller() -> None: with pytest.raises(MCPError) as exc: - await client.send_request("ping", None) + await client.send_raw_request("ping", None) assert exc.value.error.code == CONNECTION_CLOSED tg.start_soon(caller) @@ -187,13 +187,13 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): with pytest.raises(MCPError): # REQUEST_TIMEOUT - await client.send_request("slow", None, {"timeout": 0}) + await client.send_raw_request("slow", None, {"timeout": 0}) # The server handler is still running; let it finish and write a # response for an id the client has already discarded. await handler_started.wait() proceed.set() # One more round-trip proves the dispatcher is still healthy. - assert await client.send_request("ping", None) == {"late": True} + assert await client.send_raw_request("ping", None) == {"late": True} @pytest.mark.anyio @@ -234,14 +234,14 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio -async def test_ctx_send_request_tags_outbound_with_server_message_metadata(): +async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): """Server-to-client requests carry related_request_id for SHTTP routing.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - return await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + return await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: raise NotImplementedError @@ -285,7 +285,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None, {"on_progress": on_progress}) + await client.send_raw_request("t", None, {"on_progress": on_progress}) assert received == [(0.25, None, None)] @@ -297,18 +297,18 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("t", None) + await client.send_raw_request("t", None) assert exc.value.error.code == INVALID_PARAMS @pytest.mark.anyio -async def test_send_request_before_run_raises_runtimeerror(): +async def test_send_raw_request_before_run_raises_runtimeerror(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) try: with pytest.raises(RuntimeError, match="before run"): - await d.send_request("ping", None) + await d.send_raw_request("ping", None) finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): s.close() @@ -351,7 +351,7 @@ async def test_cancelled_notification_for_unknown_request_id_is_noop(): with anyio.fail_after(5): await client.notify("notifications/cancelled", {"requestId": 999}) # No effect; dispatcher remains healthy. - assert await client.send_request("t", None) == {"echoed": "t", "params": {}} + assert await client.send_raw_request("t", None) == {"echoed": "t", "params": {}} assert srec.notifications == [] # cancelled is fully consumed, never teed @@ -451,7 +451,7 @@ async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): async def caller() -> None: with caller_scope: - await client.send_request("slow", None) + await client.send_raw_request("slow", None) caller_done.set() tg.start_soon(caller) From 17340aa3d261c8083e110daa9aca63aff5479a13 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:21:27 +0000 Subject: [PATCH 12/27] feat: PeerMixin and Peer wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PeerMixin defines the typed server-to-client request methods (sample with overloads, elicit_form, elicit_url, list_roots, ping) once. Each method constrains `self: Outbound` so any class with send_request/notify can mix it in — pyright checks the host structurally at the call site. The mixin does no capability gating; that's the host's send_request's job. Peer is a trivial standalone wrapper for when you have a bare Outbound (e.g. a dispatcher) and want the typed sugar without writing your own host class. 6 tests over DirectDispatcher, 0.03s. --- src/mcp/shared/peer.py | 194 ++++++++++++++++++++++++++++++++++++++ tests/shared/test_peer.py | 128 +++++++++++++++++++++++++ 2 files changed, 322 insertions(+) create mode 100644 src/mcp/shared/peer.py create mode 100644 tests/shared/test_peer.py diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py new file mode 100644 index 0000000000..b5d4b960ed --- /dev/null +++ b/src/mcp/shared/peer.py @@ -0,0 +1,194 @@ +"""Typed MCP request sugar over an `Outbound`. + +`PeerMixin` defines the server-to-client request methods (sampling, elicitation, +roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_request` +and `notify`) can mix it in and get the typed methods for free — `Context`, +`Connection`, `Client`, or the bare `Peer` wrapper below. + +The mixin does no capability gating: it builds the params, calls +``self.send_request(method, params)``, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_request`'s job. +""" + +from collections.abc import Mapping +from typing import Any, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + IncludeContext, + ListRootsResult, + ModelPreferences, + SamplingMessage, + Tool, + ToolChoice, +) + +__all__ = ["Peer", "PeerMixin"] + + +def _dump(model: BaseModel) -> dict[str, Any]: + return model.model_dump(by_alias=True, mode="json", exclude_none=True) + + +class PeerMixin: + """Typed server-to-client request methods. + + Each method constrains ``self`` to `Outbound` so the mixin can be applied + to anything with ``send_request``/``notify`` — pyright checks the host + class structurally at the call site. + """ + + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: None = None, + tool_choice: ToolChoice | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult: ... + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool], + tool_choice: ToolChoice | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResultWithTools: ... + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult | CreateMessageResultWithTools: + """Send a ``sampling/createMessage`` request to the peer. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: The host's transport context has no + back-channel for server-initiated requests. + """ + params = CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ) + result = await self.send_request("sampling/createMessage", _dump(params), opts) + if tools is not None: + return CreateMessageResultWithTools.model_validate(result) + return CreateMessageResult.model_validate(result) + + async def elicit_form( + self: Outbound, + message: str, + requested_schema: ElicitRequestedSchema, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a form-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) + result = await self.send_request("elicitation/create", _dump(params), opts) + return ElicitResult.model_validate(result) + + async def elicit_url( + self: Outbound, + message: str, + url: str, + elicitation_id: str, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a URL-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) + result = await self.send_request("elicitation/create", _dump(params), opts) + return ElicitResult.model_validate(result) + + async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult: + """Send a ``roots/list`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + result = await self.send_request("roots/list", None, opts) + return ListRootsResult.model_validate(result) + + async def ping(self: Outbound, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request and ignore the result. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + await self.send_request("ping", None, opts) + + +class Peer(PeerMixin): + """Standalone wrapper that gives any `Outbound` the `PeerMixin` sugar. + + `Context` and `Connection` mix `PeerMixin` in directly; use `Peer` when + you have a bare dispatcher (or any `Outbound`) and want the typed methods + without writing your own host class. + """ + + def __init__(self, outbound: Outbound) -> None: + self._outbound = outbound + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + return await self._outbound.send_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._outbound.notify(method, params) diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py new file mode 100644 index 0000000000..43d49252cb --- /dev/null +++ b/tests/shared/test_peer.py @@ -0,0 +1,128 @@ +"""Tests for `PeerMixin` and `Peer`. + +Each PeerMixin method is tested by wrapping a `DirectDispatcher` in `Peer`, +calling the typed method, and asserting (a) the right method+params went out +and (b) the return value is the typed result model. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ElicitResult, + ListRootsResult, + SamplingMessage, + TextContent, + Tool, +) + +from .conftest import direct_pair +from .test_dispatcher import running_pair + +DCtx = DispatchContext[TransportContext] + + +class _Recorder: + def __init__(self, result: dict[str, Any]) -> None: + self.result = result + self.seen: list[tuple[str, Mapping[str, Any] | None]] = [] + + async def on_request(self, ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + self.seen.append((method, params)) + return self.result + + +@pytest.mark.anyio +async def test_peer_sample_sends_create_message_and_returns_typed_result(): + rec = _Recorder({"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], + max_tokens=10, + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" + + +@pytest.mark.anyio +async def test_peer_sample_with_tools_returns_with_tools_result(): + rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], + max_tokens=5, + tools=[Tool(name="t", input_schema={"type": "object"})], + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) + + +@pytest.mark.anyio +async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): + rec = _Recorder({"action": "accept", "content": {"name": "Max"}}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): + rec = _Recorder({"action": "accept"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): + rec = _Recorder({"roots": [{"uri": "file:///workspace"}]}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.list_roots() + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" + + +@pytest.mark.anyio +async def test_peer_ping_sends_ping_and_returns_none(): + rec = _Recorder({}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.ping() + method, _ = rec.seen[0] + assert method == "ping" + assert result is None From 88539b30259b2a432c3ed41a0f1d6d1c7b695dc6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:42:46 +0000 Subject: [PATCH 13/27] feat: BaseContext Composition over a DispatchContext: forwards transport/cancel_requested/ send_request/notify/progress and adds meta. Satisfies Outbound so PeerMixin works on it (proven by Peer(bctx).ping() round-tripping). The server Context (next commit) extends this with lifespan/connection; ClientContext will be an alias once ClientSession is reworked. --- src/mcp/shared/context.py | 82 +++++++++++++++++++++++++ tests/shared/test_context.py | 115 +++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 src/mcp/shared/context.py create mode 100644 tests/shared/test_context.py diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py new file mode 100644 index 0000000000..f6a33d719a --- /dev/null +++ b/src/mcp/shared/context.py @@ -0,0 +1,82 @@ +"""`BaseContext` — the user-facing per-request context. + +Composition over a `DispatchContext`: forwards the transport metadata, the +back-channel (`send_request`/`notify`), progress reporting, and the cancel +event. Adds `meta` (the inbound request's `_meta` field). + +Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` +mixes that in directly). Shared between client and server: the server's +`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an +alias. +""" + +from collections.abc import Mapping +from typing import Any, Generic + +import anyio +from typing_extensions import TypeVar + +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import RequestParamsMeta + +__all__ = ["BaseContext"] + +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) + + +class BaseContext(Generic[TransportT]): + """Per-request context wrapping a `DispatchContext`. + + `ServerRunner` (PR4) constructs one per inbound request and passes it to + the user's handler. + """ + + def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: + self._dctx = dctx + self._meta = meta + + @property + def transport(self) -> TransportT: + """Transport-specific metadata for this inbound request.""" + return self._dctx.transport + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + return self._dctx.cancel_requested + + @property + def can_send_request(self) -> bool: + """Whether the back-channel can deliver server-initiated requests.""" + return self._dctx.transport.can_send_request + + @property + def meta(self) -> RequestParamsMeta | None: + """The inbound request's ``_meta`` field, if present.""" + return self._meta + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``can_send_request`` is ``False``. + """ + return await self._dctx.send_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer on the back-channel.""" + await self._dctx.notify(method, params) + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for this request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + await self._dctx.progress(progress, total, message) diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py new file mode 100644 index 0000000000..5d93768433 --- /dev/null +++ b/tests/shared/test_context.py @@ -0,0 +1,115 @@ +"""Tests for `BaseContext`. + +`BaseContext` is composition over a `DispatchContext` — it forwards +``transport``/``cancel_requested``/``send_request``/``notify``/``progress`` +and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext + +from .conftest import direct_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_base_context_forwards_transport_and_cancel_requested(): + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + captured.append(bctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None) + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None + + +@pytest.mark.anyio +async def test_base_context_send_request_and_notify_forward_to_dispatch_context(): + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + sample = await bctx.send_request("sampling/createMessage", {"x": 1}) + await bctx.notify("notifications/message", {"level": "info"}) + return {"sample": sample} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=c_req, + client_on_notify=c_notify, + ) as (client, *_): + with anyio.fail_after(5): + result = await client.send_request("tools/call", None) + await crec.notified.wait() + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + + +@pytest.mark.anyio +async def test_base_context_report_progress_invokes_caller_on_progress(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await bctx.report_progress(0.5, total=1.0, message="halfway") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_base_context_satisfies_outbound_so_peer_mixin_works(): + """Wrapping a BaseContext in Peer proves it satisfies Outbound structurally.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await Peer(bctx).ping() + return {} + + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + async with running_pair( + direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify + ) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None) + assert crec.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_base_context_meta_holds_supplied_request_params_meta(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx, meta={"progressToken": "abc"}) + assert bctx.meta is not None and bctx.meta.get("progressToken") == "abc" + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_request("t", None) From edda6a80e919fbaa677d20844fb9392dbc5113a1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:02:36 +0000 Subject: [PATCH 14/27] refactor: follow Outbound.send_raw_request rename in PeerMixin/BaseContext PeerMixin methods and Peer/BaseContext now call/expose send_raw_request. The typed send_request lands on Connection/Context in the next commit. --- src/mcp/shared/context.py | 6 +++--- src/mcp/shared/peer.py | 22 +++++++++++----------- tests/shared/test_context.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f6a33d719a..68f439b738 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,7 +1,7 @@ """`BaseContext` — the user-facing per-request context. Composition over a `DispatchContext`: forwards the transport metadata, the -back-channel (`send_request`/`notify`), progress reporting, and the cancel +back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel event. Adds `meta` (the inbound request's `_meta` field). Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` @@ -56,7 +56,7 @@ def meta(self) -> RequestParamsMeta | None: """The inbound request's ``_meta`` field, if present.""" return self._meta - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -68,7 +68,7 @@ async def send_request( MCPError: The peer responded with an error. NoBackChannelError: ``can_send_request`` is ``False``. """ - return await self._dctx.send_request(method, params, opts) + return await self._dctx.send_raw_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: """Send a notification to the peer on the back-channel.""" diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index b5d4b960ed..9951081104 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -1,13 +1,13 @@ """Typed MCP request sugar over an `Outbound`. `PeerMixin` defines the server-to-client request methods (sampling, elicitation, -roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_request` +roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_raw_request` and `notify`) can mix it in and get the typed methods for free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. The mixin does no capability gating: it builds the params, calls -``self.send_request(method, params)``, and parses the result into the typed -model. Gating (and `NoBackChannelError`) is the host's `send_request`'s job. +``self.send_raw_request(method, params)``, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_raw_request`'s job. """ from collections.abc import Mapping @@ -43,7 +43,7 @@ class PeerMixin: """Typed server-to-client request methods. Each method constrains ``self`` to `Outbound` so the mixin can be applied - to anything with ``send_request``/``notify`` — pyright checks the host + to anything with ``send_raw_request``/``notify`` — pyright checks the host class structurally at the call site. """ @@ -113,7 +113,7 @@ async def sample( tools=tools, tool_choice=tool_choice, ) - result = await self.send_request("sampling/createMessage", _dump(params), opts) + result = await self.send_raw_request("sampling/createMessage", _dump(params), opts) if tools is not None: return CreateMessageResultWithTools.model_validate(result) return CreateMessageResult.model_validate(result) @@ -131,7 +131,7 @@ async def elicit_form( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) - result = await self.send_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", _dump(params), opts) return ElicitResult.model_validate(result) async def elicit_url( @@ -148,7 +148,7 @@ async def elicit_url( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) - result = await self.send_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", _dump(params), opts) return ElicitResult.model_validate(result) async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult: @@ -158,7 +158,7 @@ async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRoo MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - result = await self.send_request("roots/list", None, opts) + result = await self.send_raw_request("roots/list", None, opts) return ListRootsResult.model_validate(result) async def ping(self: Outbound, opts: CallOptions | None = None) -> None: @@ -168,7 +168,7 @@ async def ping(self: Outbound, opts: CallOptions | None = None) -> None: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - await self.send_request("ping", None, opts) + await self.send_raw_request("ping", None, opts) class Peer(PeerMixin): @@ -182,13 +182,13 @@ class Peer(PeerMixin): def __init__(self, outbound: Outbound) -> None: self._outbound = outbound - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, ) -> dict[str, Any]: - return await self._outbound.send_request(method, params, opts) + return await self._outbound.send_raw_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._outbound.notify(method, params) diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index 5d93768433..951690028f 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -1,7 +1,7 @@ """Tests for `BaseContext`. `BaseContext` is composition over a `DispatchContext` — it forwards -``transport``/``cancel_requested``/``send_request``/``notify``/``progress`` +``transport``/``cancel_requested``/``send_raw_request``/``notify``/``progress`` and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. """ @@ -33,7 +33,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None) + await client.send_raw_request("t", None) bctx = captured[0] assert bctx.transport.kind == "direct" assert isinstance(bctx.cancel_requested, anyio.Event) @@ -42,13 +42,13 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio -async def test_base_context_send_request_and_notify_forward_to_dispatch_context(): +async def test_base_context_send_raw_request_and_notify_forward_to_dispatch_context(): crec = Recorder() c_req, c_notify = echo_handlers(crec) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: bctx = BaseContext(ctx) - sample = await bctx.send_request("sampling/createMessage", {"x": 1}) + sample = await bctx.send_raw_request("sampling/createMessage", {"x": 1}) await bctx.notify("notifications/message", {"level": "info"}) return {"sample": sample} @@ -59,7 +59,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | client_on_notify=c_notify, ) as (client, *_): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) await crec.notified.wait() assert crec.requests == [("sampling/createMessage", {"x": 1})] assert crec.notifications == [("notifications/message", {"level": "info"})] @@ -80,7 +80,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None, {"on_progress": on_progress}) + await client.send_raw_request("t", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @@ -99,7 +99,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify ) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None) + await client.send_raw_request("t", None) assert crec.requests == [("ping", None)] @@ -112,4 +112,4 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("t", None) + await client.send_raw_request("t", None) From b538bca27485bcbe84fa7dae45a8f0eb817ee97d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:21:02 +0000 Subject: [PATCH 15/27] feat: Connection, server Context, typed send_request, meta kwarg TypedServerRequestMixin (server/_typed_request.py) provides shape-2 typed send_request: per-spec overloads (CreateMessage/Elicit/ListRoots/Ping) infer the result type; custom requests pass result_type explicitly. Mixed into both Connection and the server Context. Connection (server/connection.py) wraps an Outbound for the standalone stream. notify is best-effort (never raises); send_raw_request gated on has_standalone_channel; check_capability mirrors v1 for now (FOLLOWUP). Holds peer info populated at initialize time and the per-connection lifespan state. Context (server/context.py, alongside v1's ServerRequestContext) composes BaseContext + PeerMixin + TypedServerRequestMixin and adds lifespan/connection. Request-scoped log() rides the request's back-channel; ctx.connection.log() uses the standalone stream. dump_params(model, meta) merges user-supplied meta into _meta; threaded through every PeerMixin and Connection convenience method. 31 tests, 0.06s. --- src/mcp/server/_typed_request.py | 85 +++++++++++++ src/mcp/server/connection.py | 146 ++++++++++++++++++++++ src/mcp/server/context.py | 60 +++++++++ src/mcp/shared/peer.py | 48 ++++++-- tests/server/test_connection.py | 184 ++++++++++++++++++++++++++++ tests/server/test_server_context.py | 131 ++++++++++++++++++++ tests/shared/test_peer.py | 21 +++- 7 files changed, 661 insertions(+), 14 deletions(-) create mode 100644 src/mcp/server/_typed_request.py create mode 100644 src/mcp/server/connection.py create mode 100644 tests/server/test_connection.py create mode 100644 tests/server/test_server_context.py diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py new file mode 100644 index 0000000000..50cae159d1 --- /dev/null +++ b/src/mcp/server/_typed_request.py @@ -0,0 +1,85 @@ +"""Shape-2 typed ``send_request`` for server-to-client requests. + +`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over +the host's raw `Outbound.send_raw_request`. Spec server-to-client request types +have their result type inferred via per-type overloads; custom requests pass +``result_type=`` explicitly. + +A `HasResult[R]` protocol (one generic signature, mapping declared on the +request type) is the cleaner long-term shape — see FOLLOWUPS.md. This per-spec +overload set is used for now to avoid touching `mcp.types`. +""" + +from typing import Any, TypeVar, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.peer import dump_params +from mcp.types import ( + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + Request, +) + +__all__ = ["TypedServerRequestMixin"] + +ResultT = TypeVar("ResultT", bound=BaseModel) + +_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = { + CreateMessageRequest: CreateMessageResult, + ElicitRequest: ElicitResult, + ListRootsRequest: ListRootsResult, + PingRequest: EmptyResult, +} + + +class TypedServerRequestMixin: + """Typed ``send_request`` for the server-to-client request set. + + Mixed into `Connection` and the server `Context`. Each method constrains + ``self`` to `Outbound` so any host with ``send_raw_request`` works. + """ + + @overload + async def send_request( + self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None + ) -> CreateMessageResult: ... + @overload + async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ... + @overload + async def send_request( + self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None + ) -> ListRootsResult: ... + @overload + async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ... + @overload + async def send_request( + self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None + ) -> ResultT: ... + async def send_request( + self: Outbound, + req: Request[Any, Any], + *, + result_type: type[BaseModel] | None = None, + opts: CallOptions | None = None, + ) -> BaseModel: + """Send a typed server-to-client request and return its typed result. + + For spec request types the result type is inferred. For custom requests + pass ``result_type=`` explicitly. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + KeyError: ``result_type`` omitted for a non-spec request type. + """ + raw = await self.send_raw_request(req.method, dump_params(req.params), opts) + cls = result_type if result_type is not None else _RESULT_FOR[type(req)] + return cls.model_validate(raw) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py new file mode 100644 index 0000000000..72c4ed062f --- /dev/null +++ b/src/mcp/server/connection.py @@ -0,0 +1,146 @@ +"""`Connection` — per-client connection state and the standalone outbound channel. + +Always present on `Context` (never ``None``), even in stateless deployments. +Holds peer info populated at ``initialize`` time, the per-connection lifespan +output, and an `Outbound` for the standalone stream (the SSE GET stream in +streamable HTTP, or the single duplex stream in stdio). + +`notify` is best-effort: it never raises. If there's no standalone channel +(stateless HTTP) or the stream has been dropped, the notification is +debug-logged and silently discarded — server-initiated notifications are +inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when +there's no channel; `ping` is the only spec-sanctioned standalone request. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio + +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.peer import Meta, dump_params +from mcp.types import ClientCapabilities, Implementation, LoggingLevel + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + + +def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None: + if not meta: + return payload + out = dict(payload or {}) + out["_meta"] = meta + return out + + +class Connection(TypedServerRequestMixin): + """Per-client connection state and standalone-stream `Outbound`. + + Constructed by `ServerRunner` once per connection. The peer-info fields are + ``None`` until ``initialize`` completes; ``initialized`` is set then. + """ + + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: + self._outbound = outbound + self.has_standalone_channel = has_standalone_channel + + self.client_info: Implementation | None = None + self.client_capabilities: ClientCapabilities | None = None + self.protocol_version: str | None = None + self.initialized: anyio.Event = anyio.Event() + # TODO: make this generic (Connection[StateT]) once connection_lifespan + # wiring lands in ServerRunner — see FOLLOWUPS.md. + self.state: Any = None + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a raw request on the standalone stream. + + Low-level `Outbound` channel. Prefer the typed ``send_request`` (from + `TypedServerRequestMixin`) or the convenience methods below; use this + directly only for off-spec messages. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + if not self.has_standalone_channel: + raise NoBackChannelError(method) + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a best-effort notification on the standalone stream. + + Never raises. If there's no standalone channel or the stream is broken, + the notification is dropped and debug-logged. + """ + if not self.has_standalone_channel: + logger.debug("dropped %s: no standalone channel", method) + return + try: + await self._outbound.notify(method, params) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped %s: standalone stream closed", method) + + async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request on the standalone stream. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a ``notifications/message`` log entry on the standalone stream. Best-effort.""" + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + await self.notify("notifications/message", _notification_params(params, meta)) + + async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/tools/list_changed", _notification_params(None, meta)) + + async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/prompts/list_changed", _notification_params(None, meta)) + + async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/list_changed", _notification_params(None, meta)) + + async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta)) + + def check_capability(self, capability: ClientCapabilities) -> bool: + """Return whether the connected client declared the given capability. + + Returns ``False`` if ``initialize`` hasn't completed yet. + """ + # TODO: redesign — mirrors v1 ServerSession.check_client_capability + # verbatim for parity. See FOLLOWUPS.md. + if self.client_capabilities is None: + return False + have = self.client_capabilities + if capability.roots is not None: + if have.roots is None: + return False + if capability.roots.list_changed and not have.roots.list_changed: + return False + if capability.sampling is not None and have.sampling is None: + return False + if capability.elicitation is not None and have.elicitation is None: + return False + if capability.experimental is not None: + if have.experimental is None: + return False + for k in capability.experimental: + if k not in have.experimental: + return False + return True diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..b7b97acf8b 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -5,10 +5,17 @@ from typing_extensions import TypeVar +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.server.connection import Connection from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext from mcp.shared.message import CloseSSEStreamCallback +from mcp.shared.peer import Meta, PeerMixin +from mcp.shared.transport_context import TransportContext +from mcp.types import LoggingLevel, RequestParamsMeta LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @@ -21,3 +28,56 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None + + +LifespanT = TypeVar("LifespanT", default=Any) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) + + +class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): + """Server-side per-request context. + + Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), + `PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``), + and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds + ``lifespan`` and ``connection``. + + Constructed by `ServerRunner` (PR4) per inbound request and handed to the + user's handler. + """ + + def __init__( + self, + dctx: DispatchContext[TransportT], + *, + lifespan: LifespanT, + connection: Connection, + meta: RequestParamsMeta | None = None, + ) -> None: + super().__init__(dctx, meta=meta) + self._lifespan = lifespan + self._connection = connection + + @property + def lifespan(self) -> LifespanT: + """The server-wide lifespan output (what `Server(..., lifespan=...)` yielded).""" + return self._lifespan + + @property + def connection(self) -> Connection: + """The per-client `Connection` for this request's connection.""" + return self._connection + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a request-scoped ``notifications/message`` log entry. + + Uses this request's back-channel (so the entry rides the request's SSE + stream in streamable HTTP), not the standalone stream — use + ``ctx.connection.log(...)`` for that. + """ + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + if meta: + params["_meta"] = meta + await self.notify("notifications/message", params) diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index 9951081104..47b64c7769 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -1,9 +1,9 @@ """Typed MCP request sugar over an `Outbound`. `PeerMixin` defines the server-to-client request methods (sampling, elicitation, -roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_raw_request` -and `notify`) can mix it in and get the typed methods for free — `Context`, -`Connection`, `Client`, or the bare `Peer` wrapper below. +roots, ping) once. Any class that satisfies `Outbound` (i.e. has +``send_raw_request`` and ``notify``) can mix it in and get the typed methods for +free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. The mixin does no capability gating: it builds the params, calls ``self.send_raw_request(method, params)``, and parses the result into the typed @@ -32,11 +32,24 @@ ToolChoice, ) -__all__ = ["Peer", "PeerMixin"] +__all__ = ["Meta", "Peer", "PeerMixin", "dump_params"] +Meta = dict[str, Any] +"""Type alias for the ``_meta`` field carried on request/notification params.""" -def _dump(model: BaseModel) -> dict[str, Any]: - return model.model_dump(by_alias=True, mode="json", exclude_none=True) + +def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: + """Serialize a params model to a wire dict, merging ``meta`` into ``_meta``. + + Shared by `PeerMixin`, `Connection`, and `TypedServerRequestMixin` so every + typed convenience method gets the same `_meta` handling. ``meta`` keys take + precedence over any ``_meta`` already present on the model. + """ + out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None + if meta: + out = dict(out or {}) + out["_meta"] = {**out.get("_meta", {}), **meta} + return out class PeerMixin: @@ -61,6 +74,7 @@ async def sample( model_preferences: ModelPreferences | None = None, tools: None = None, tool_choice: ToolChoice | None = None, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResult: ... @overload @@ -77,6 +91,7 @@ async def sample( model_preferences: ModelPreferences | None = None, tools: list[Tool], tool_choice: ToolChoice | None = None, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResultWithTools: ... async def sample( @@ -92,6 +107,7 @@ async def sample( model_preferences: ModelPreferences | None = None, tools: list[Tool] | None = None, tool_choice: ToolChoice | None = None, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> CreateMessageResult | CreateMessageResultWithTools: """Send a ``sampling/createMessage`` request to the peer. @@ -113,7 +129,7 @@ async def sample( tools=tools, tool_choice=tool_choice, ) - result = await self.send_raw_request("sampling/createMessage", _dump(params), opts) + result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) if tools is not None: return CreateMessageResultWithTools.model_validate(result) return CreateMessageResult.model_validate(result) @@ -122,6 +138,8 @@ async def elicit_form( self: Outbound, message: str, requested_schema: ElicitRequestedSchema, + *, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> ElicitResult: """Send a form-mode ``elicitation/create`` request. @@ -131,7 +149,7 @@ async def elicit_form( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) - result = await self.send_raw_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) return ElicitResult.model_validate(result) async def elicit_url( @@ -139,6 +157,8 @@ async def elicit_url( message: str, url: str, elicitation_id: str, + *, + meta: Meta | None = None, opts: CallOptions | None = None, ) -> ElicitResult: """Send a URL-mode ``elicitation/create`` request. @@ -148,27 +168,29 @@ async def elicit_url( NoBackChannelError: No back-channel for server-initiated requests. """ params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) - result = await self.send_raw_request("elicitation/create", _dump(params), opts) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) return ElicitResult.model_validate(result) - async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult: + async def list_roots( + self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None + ) -> ListRootsResult: """Send a ``roots/list`` request. Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - result = await self.send_raw_request("roots/list", None, opts) + result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) return ListRootsResult.model_validate(result) - async def ping(self: Outbound, opts: CallOptions | None = None) -> None: + async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: """Send a ``ping`` request and ignore the result. Raises: MCPError: The peer responded with an error. NoBackChannelError: No back-channel for server-initiated requests. """ - await self.send_raw_request("ping", None, opts) + await self.send_raw_request("ping", dump_params(None, meta), opts) class Peer(PeerMixin): diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py new file mode 100644 index 0000000000..eb3440085a --- /dev/null +++ b/tests/server/test_connection.py @@ -0,0 +1,184 @@ +"""Tests for `Connection`. + +`Connection` wraps an `Outbound` (the standalone stream). Its `notify` is +best-effort (never raises); `send_raw_request` is gated on +``has_standalone_channel``. Tested with a stub `Outbound` so we can assert wire +shape and inject failures. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.types import ( + ClientCapabilities, + ElicitationCapability, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + RootsCapability, + SamplingCapability, +) + + +class StubOutbound: + def __init__( + self, *, result: dict[str, Any] | None = None, raise_on_send: type[BaseException] | None = None + ) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self._result = result if result is not None else {} + self._raise_on_send = raise_on_send + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.requests.append((method, params)) + return self._result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._raise_on_send is not None: + raise self._raise_on_send() + self.notifications.append((method, params)) + + +@pytest.mark.anyio +async def test_connection_notify_forwards_to_outbound(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"level": "info", "data": "hi"}) + assert out.notifications == [("notifications/message", {"level": "info", "data": "hi"})] + + +@pytest.mark.anyio +async def test_connection_notify_swallows_broken_stream_and_debug_logs(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound(raise_on_send=anyio.BrokenResourceError) + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert "stream closed" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_notify_drops_when_no_standalone_channel(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound() + conn = Connection(out, has_standalone_channel=False) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert out.notifications == [] + assert "no standalone channel" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalone_channel(): + conn = Connection(StubOutbound(), has_standalone_channel=False) + with pytest.raises(NoBackChannelError): + await conn.send_raw_request("ping", None) + + +@pytest.mark.anyio +async def test_connection_send_raw_request_forwards_when_standalone_channel_present(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_raw_request("ping", None) + assert out.requests == [("ping", None)] + assert result == {} + + +@pytest.mark.anyio +async def test_connection_send_request_with_spec_type_infers_result_type(): + out = StubOutbound(result={"roots": [{"uri": "file:///ws"}]}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(ListRootsRequest()) + method, _ = out.requests[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert str(result.roots[0].uri) == "file:///ws" + + +@pytest.mark.anyio +async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): + out = StubOutbound(result={}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(PingRequest(), result_type=EmptyResult) + assert isinstance(result, EmptyResult) + + +@pytest.mark.anyio +async def test_connection_ping_sends_ping_on_standalone(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.ping() + assert out.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_connection_log_sends_logging_message_notification(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", {"k": "v"}, logger="my.logger") + method, params = out.notifications[0] + assert method == "notifications/message" + assert params is not None + assert params["level"] == "info" + assert params["data"] == {"k": "v"} + assert params["logger"] == "my.logger" + + +@pytest.mark.anyio +async def test_connection_log_with_meta_includes_meta_in_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", "x", meta={"traceId": "abc"}) + _, params = out.notifications[0] + assert params is not None + assert params["_meta"] == {"traceId": "abc"} + + +@pytest.mark.anyio +async def test_connection_list_changed_notifications_send_correct_methods(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed() + await conn.send_prompt_list_changed() + await conn.send_resource_list_changed() + await conn.send_resource_updated("file:///workspace/a.txt") + methods = [m for m, _ in out.notifications] + assert methods == [ + "notifications/tools/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/list_changed", + "notifications/resources/updated", + ] + assert out.notifications[-1][1] == {"uri": "file:///workspace/a.txt"} + + +@pytest.mark.anyio +async def test_connection_send_tool_list_changed_with_meta_includes_meta_only_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed(meta={"k": 1}) + assert out.notifications == [("notifications/tools/list_changed", {"_meta": {"k": 1}})] + + +def test_connection_check_capability_false_before_initialized(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False + + +def test_connection_check_capability_true_when_client_declares_it(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = ClientCapabilities( + sampling=SamplingCapability(), roots=RootsCapability(list_changed=True) + ) + conn.initialized.set() + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert conn.check_capability(ClientCapabilities(roots=RootsCapability(list_changed=True))) is True + assert conn.check_capability(ClientCapabilities(elicitation=ElicitationCapability())) is False diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py new file mode 100644 index 0000000000..65db51c4a5 --- /dev/null +++ b/tests/server/test_server_context.py @@ -0,0 +1,131 @@ +"""Tests for the server-side `Context`. + +`Context` composes `BaseContext` (forwarding to a `DispatchContext`) with +`PeerMixin` (typed sample/elicit/roots/ping) plus `lifespan` and `connection`. +End-to-end tested over `DirectDispatcher`. +""" + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import CreateMessageResult, ListRootsRequest, ListRootsResult, SamplingMessage, TextContent + +from ..shared.conftest import direct_pair +from ..shared.test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@dataclass +class _Lifespan: + name: str + + +@pytest.mark.anyio +async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): + captured: list[Context[_Lifespan, TransportContext]] = [] + conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + captured.append(ctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): + # Now we have the server dispatcher; build the real Connection bound to it. + conn.__init__(server, has_standalone_channel=True) + with anyio.fail_after(5): + await client.send_raw_request("t", None) + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True + + +@pytest.mark.anyio +async def test_context_sample_round_trips_via_peer_mixin_on_base_context_outbound(): + crec = Recorder() + + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + crec.requests.append((method, params)) + return {"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"} + + results: list[CreateMessageResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append( + await ctx.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hi"))], + max_tokens=5, + ) + ) + return {} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=client_on_request, + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) + + +@pytest.mark.anyio +async def test_context_send_request_with_spec_type_infers_result_via_typed_mixin(): + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return {"roots": []} + + results: list[ListRootsResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append(await ctx.send_request(ListRootsRequest())) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_request=client_on_request) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert isinstance(results[0], ListRootsResult) + + +@pytest.mark.anyio +async def test_context_log_sends_request_scoped_message_notification(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("debug", "hello") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 43d49252cb..0d7d9e9bae 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -12,7 +12,7 @@ import pytest from mcp.shared.dispatcher import DispatchContext -from mcp.shared.peer import Peer +from mcp.shared.peer import Peer, dump_params from mcp.shared.transport_context import TransportContext from mcp.types import ( CreateMessageResult, @@ -116,6 +116,25 @@ async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): assert str(result.roots[0].uri) == "file:///workspace" +@pytest.mark.anyio +async def test_peer_list_roots_with_meta_sends_meta_in_params(): + rec = _Recorder({"roots": []}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + await peer.list_roots(meta={"traceId": "t1"}) + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} + + +def test_dump_params_merges_meta_over_model_meta(): + out = dump_params(None, None) + assert out is None + out = dump_params(None, {"k": 1}) + assert out == {"_meta": {"k": 1}} + + @pytest.mark.anyio async def test_peer_ping_sends_ping_and_returns_none(): rec = _Recorder({}) From 1457e393b786cbee81934191b7e75517ab43293a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:44:41 +0000 Subject: [PATCH 16/27] test: close PR3 coverage gaps to 100% - Connection.check_capability per-field branches (parametrized) - Context.log with logger and meta supplied - Peer.notify forwards to wrapped Outbound --- tests/server/test_connection.py | 21 +++++++++++++++++++++ tests/server/test_server_context.py | 25 +++++++++++++++++++++++++ tests/shared/test_peer.py | 17 +++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index eb3440085a..ded9dfd6ac 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -173,6 +173,27 @@ def test_connection_check_capability_false_before_initialized(): assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False +@pytest.mark.parametrize( + ("have", "want", "expected"), + [ + (ClientCapabilities(roots=None), ClientCapabilities(roots=RootsCapability()), False), + ( + ClientCapabilities(roots=RootsCapability(list_changed=False)), + ClientCapabilities(roots=RootsCapability(list_changed=True)), + False, + ), + (ClientCapabilities(sampling=None), ClientCapabilities(sampling=SamplingCapability()), False), + (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"a": {}}), True), + ], +) +def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = have + assert conn.check_capability(want) is expected + + def test_connection_check_capability_true_when_client_declares_it(): conn = Connection(StubOutbound(), has_standalone_channel=True) conn.client_capabilities = ClientCapabilities( diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index 65db51c4a5..eb2df9b649 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -129,3 +129,28 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | method, params = crec.notifications[0] assert method == "notifications/message" assert params is not None and params["level"] == "debug" and params["data"] == "hello" + + +@pytest.mark.anyio +async def test_context_log_includes_logger_and_meta_when_supplied(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan, TransportContext] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 0d7d9e9bae..589994c818 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -135,6 +135,23 @@ def test_dump_params_merges_meta_over_model_meta(): assert out == {"_meta": {"k": 1}} +@pytest.mark.anyio +async def test_peer_notify_forwards_to_wrapped_outbound(): + sent: list[tuple[str, Mapping[str, Any] | None]] = [] + + class _Out: + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: Any = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + sent.append((method, params)) + + await Peer(_Out()).notify("n", {"x": 1}) + assert sent == [("n", {"x": 1})] + + @pytest.mark.anyio async def test_peer_ping_sends_ping_and_returns_none(): rec = _Recorder({}) From ceeb84a8d53cdb231bf7a5c8904c6fd47a7ee0a5 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:58:42 +0000 Subject: [PATCH 17/27] test: move asserts inside async-with for 3.11 coverage instrumentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit coverage.py on Python 3.11 doesn't record statements after an 'async with running_pair(...)' exit when there's a nested 'with anyio.fail_after()' inside. Same workaround as 0a8f0f4 in PR2 — move the asserts inside the async-with block. --- tests/server/test_server_context.py | 30 +++++++-------- tests/shared/test_context.py | 20 +++++----- tests/shared/test_peer.py | 60 ++++++++++++++--------------- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index eb2df9b649..e01de34d33 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -44,11 +44,11 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | conn.__init__(server, has_standalone_channel=True) with anyio.fail_after(5): await client.send_raw_request("t", None) - ctx = captured[0] - assert ctx.lifespan.name == "app" - assert ctx.connection is conn - assert ctx.transport.kind == "direct" - assert ctx.can_send_request is True + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True @pytest.mark.anyio @@ -80,8 +80,8 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | ) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) - assert crec.requests[0][0] == "sampling/createMessage" - assert isinstance(results[0], CreateMessageResult) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) @pytest.mark.anyio @@ -104,7 +104,7 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | ): with anyio.fail_after(5): await client.send_raw_request("t", None) - assert isinstance(results[0], ListRootsResult) + assert isinstance(results[0], ListRootsResult) @pytest.mark.anyio @@ -126,9 +126,9 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): await client.send_raw_request("t", None) await crec.notified.wait() - method, params = crec.notifications[0] - assert method == "notifications/message" - assert params is not None and params["level"] == "debug" and params["data"] == "hello" + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" @pytest.mark.anyio @@ -150,7 +150,7 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): await client.send_raw_request("t", None) await crec.notified.wait() - _, params = crec.notifications[0] - assert params is not None - assert params["logger"] == "my.log" - assert params["_meta"] == {"traceId": "t"} + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index 951690028f..882f90bfab 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -34,11 +34,11 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None) - bctx = captured[0] - assert bctx.transport.kind == "direct" - assert isinstance(bctx.cancel_requested, anyio.Event) - assert bctx.can_send_request is True - assert bctx.meta is None + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None @pytest.mark.anyio @@ -61,9 +61,9 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): result = await client.send_raw_request("tools/call", None) await crec.notified.wait() - assert crec.requests == [("sampling/createMessage", {"x": 1})] - assert crec.notifications == [("notifications/message", {"level": "info"})] - assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} @pytest.mark.anyio @@ -81,7 +81,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None, {"on_progress": on_progress}) - assert received == [(0.5, 1.0, "halfway")] + assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio @@ -100,7 +100,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | ) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None) - assert crec.requests == [("ping", None)] + assert crec.requests == [("ping", None)] @pytest.mark.anyio diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 589994c818..0be4225818 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -50,11 +50,11 @@ async def test_peer_sample_sends_create_message_and_returns_typed_result(): [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], max_tokens=10, ) - method, params = rec.seen[0] - assert method == "sampling/createMessage" - assert params is not None and params["maxTokens"] == 10 - assert isinstance(result, CreateMessageResult) - assert result.model == "m" + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" @pytest.mark.anyio @@ -68,10 +68,10 @@ async def test_peer_sample_with_tools_returns_with_tools_result(): max_tokens=5, tools=[Tool(name="t", input_schema={"type": "object"})], ) - method, params = rec.seen[0] - assert method == "sampling/createMessage" - assert params is not None and params["tools"][0]["name"] == "t" - assert isinstance(result, CreateMessageResultWithTools) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) @pytest.mark.anyio @@ -81,11 +81,11 @@ async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): peer = Peer(client) with anyio.fail_after(5): result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) - method, params = rec.seen[0] - assert method == "elicitation/create" - assert params is not None and params["mode"] == "form" - assert params["message"] == "Your name?" - assert isinstance(result, ElicitResult) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) @pytest.mark.anyio @@ -95,11 +95,11 @@ async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): peer = Peer(client) with anyio.fail_after(5): result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") - method, params = rec.seen[0] - assert method == "elicitation/create" - assert params is not None and params["mode"] == "url" - assert params["url"] == "https://example.com/auth" - assert isinstance(result, ElicitResult) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) @pytest.mark.anyio @@ -109,11 +109,11 @@ async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): peer = Peer(client) with anyio.fail_after(5): result = await peer.list_roots() - method, _ = rec.seen[0] - assert method == "roots/list" - assert isinstance(result, ListRootsResult) - assert len(result.roots) == 1 - assert str(result.roots[0].uri) == "file:///workspace" + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" @pytest.mark.anyio @@ -123,9 +123,9 @@ async def test_peer_list_roots_with_meta_sends_meta_in_params(): peer = Peer(client) with anyio.fail_after(5): await peer.list_roots(meta={"traceId": "t1"}) - method, params = rec.seen[0] - assert method == "roots/list" - assert params == {"_meta": {"traceId": "t1"}} + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} def test_dump_params_merges_meta_over_model_meta(): @@ -159,6 +159,6 @@ async def test_peer_ping_sends_ping_and_returns_none(): peer = Peer(client) with anyio.fail_after(5): result = await peer.ping() - method, _ = rec.seen[0] - assert method == "ping" - assert result is None + method, _ = rec.seen[0] + assert method == "ping" + assert result is None From 0fa4defce9667d73d400a15b35abcfbe63770b6c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 17:04:07 +0000 Subject: [PATCH 18/27] docs: drop development-journal language from docstrings/comments Remove references to PR numbers, internal scratch notes, and design-spike shorthand that won't make sense to a fresh reader of the codebase. --- src/mcp/server/_typed_request.py | 9 +++++---- src/mcp/server/connection.py | 4 ++-- src/mcp/server/context.py | 4 ++-- src/mcp/shared/context.py | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py index 50cae159d1..4334b20a94 100644 --- a/src/mcp/server/_typed_request.py +++ b/src/mcp/server/_typed_request.py @@ -1,13 +1,14 @@ -"""Shape-2 typed ``send_request`` for server-to-client requests. +"""Typed ``send_request`` for server-to-client requests. `TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over the host's raw `Outbound.send_raw_request`. Spec server-to-client request types have their result type inferred via per-type overloads; custom requests pass ``result_type=`` explicitly. -A `HasResult[R]` protocol (one generic signature, mapping declared on the -request type) is the cleaner long-term shape — see FOLLOWUPS.md. This per-spec -overload set is used for now to avoid touching `mcp.types`. +If the spec's request set grows substantially, consider declaring the result +mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via +a structural protocol) so this overload ladder doesn't need maintaining +per-host-class. """ from typing import Any, TypeVar, overload diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 72c4ed062f..df3652ce0e 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -53,7 +53,7 @@ def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: self.protocol_version: str | None = None self.initialized: anyio.Event = anyio.Event() # TODO: make this generic (Connection[StateT]) once connection_lifespan - # wiring lands in ServerRunner — see FOLLOWUPS.md. + # wiring lands in ServerRunner. self.state: Any = None async def send_raw_request( @@ -124,7 +124,7 @@ def check_capability(self, capability: ClientCapabilities) -> bool: Returns ``False`` if ``initialize`` hasn't completed yet. """ # TODO: redesign — mirrors v1 ServerSession.check_client_capability - # verbatim for parity. See FOLLOWUPS.md. + # verbatim for parity. if self.client_capabilities is None: return False have = self.client_capabilities diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index b7b97acf8b..4f0cffd9ad 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -42,8 +42,8 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds ``lifespan`` and ``connection``. - Constructed by `ServerRunner` (PR4) per inbound request and handed to the - user's handler. + Constructed by `ServerRunner` per inbound request and handed to the user's + handler. """ def __init__( diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 68f439b738..38ca8bd9b4 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -28,8 +28,8 @@ class BaseContext(Generic[TransportT]): """Per-request context wrapping a `DispatchContext`. - `ServerRunner` (PR4) constructs one per inbound request and passes it to - the user's handler. + `ServerRunner` constructs one per inbound request and passes it to the + user's handler. """ def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: From b4ef565880d55d10f011d3eb49295fcbff6797d5 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:01:09 +0000 Subject: [PATCH 19/27] refactor: make BaseContext/Context covariant in their type params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LifespanT and TransportT are only exposed via read-only properties (lifespan, transport), so covariance is sound. This lets a Context[AppState, HttpTC] be passed where a Context[object, TransportContext] is expected — needed for ServerRunner's middleware chain to compose without casts, and for reusable middleware to be typed Context[object, TransportContext] instead of relying on Any-slack. --- src/mcp/server/context.py | 4 ++-- src/mcp/shared/context.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4f0cffd9ad..4d35f8a902 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -30,8 +30,8 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex close_standalone_sse_stream: CloseSSEStreamCallback | None = None -LifespanT = TypeVar("LifespanT", default=Any) -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) +LifespanT = TypeVar("LifespanT", default=Any, covariant=True) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 38ca8bd9b4..ff69c48401 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -22,7 +22,7 @@ __all__ = ["BaseContext"] -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) class BaseContext(Generic[TransportT]): From 22baccd0cd86e59c9cc5492055528114efa86a63 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 18:18:20 +0000 Subject: [PATCH 20/27] =?UTF-8?q?feat:=20ServerRunner=20skeleton=20?= =?UTF-8?q?=E2=80=94=20=5Fon=5Frequest,=20initialize,=20init-gate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ServerRunner is the per-connection orchestrator over a Dispatcher. This commit lands the skeleton: ServerRegistry Protocol, _on_request (lookup → validate → build Context → call handler → dump), _handle_initialize (populates Connection, opens the init-gate), and a basic _on_notify. Additive methods on lowlevel Server (get_request_handler / get_notification_handler / middleware / connection_lifespan) so it satisfies ServerRegistry without touching the existing run() path. _PARAMS_FOR_METHOD is scaffolding (marked TODO) until the registry stores params types directly. 5 tests over DirectDispatcher + a real lowlevel Server. --- src/mcp/server/lowlevel/server.py | 20 +++ src/mcp/server/runner.py | 218 ++++++++++++++++++++++++++++++ tests/server/test_runner.py | 154 +++++++++++++++++++++ 3 files changed, 392 insertions(+) create mode 100644 src/mcp/server/runner.py create mode 100644 tests/server/test_runner.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5e4e2e6f5b..de12832dc5 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -246,6 +246,26 @@ def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a request method, or ``None``.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a notification method, or ``None``.""" + return self._notification_handlers.get(method) + + @property + def middleware(self) -> list[Any]: + """Context-tier middleware. Empty until the registry refactor adds registration.""" + return [] + + @property + def connection_lifespan(self) -> None: + """Per-connection lifespan. ``None`` until the registry refactor adds it.""" + return None + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 0000000000..66093b250a --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,218 @@ +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. It: + +* handles the ``initialize`` handshake and populates `Connection` +* gates requests until initialized (``ping`` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives ``dispatcher.run()`` and the per-connection lifespan + +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies +it via additive methods so the existing ``Server.run()`` path is unaffected. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Protocol, cast + +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared.dispatcher import DispatchContext, Dispatcher +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + CompleteRequestParams, + GetPromptRequestParams, + Implementation, + InitializeRequestParams, + InitializeResult, + NotificationParams, + PaginatedRequestParams, + ProgressNotificationParams, + ReadResourceRequestParams, + RequestParams, + ServerCapabilities, + SetLevelRequestParams, + SubscribeRequestParams, + UnsubscribeRequestParams, +) + +__all__ = ["ServerRegistry", "ServerRunner"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) +ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) + +Handler = Callable[..., Awaitable[Any]] +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely +so the existing `ServerRequestContext`-based handlers and the new +`Context`-based handlers both fit during the transition. +""" + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + +# TODO: remove this lookup once `Server` stores (params_type, handler) in its +# registry directly. This is scaffolding so ServerRunner can validate params +# without changing the existing `_request_handlers` dict shape. +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { + "ping": RequestParams, + "tools/list": PaginatedRequestParams, + "tools/call": CallToolRequestParams, + "prompts/list": PaginatedRequestParams, + "prompts/get": GetPromptRequestParams, + "resources/list": PaginatedRequestParams, + "resources/templates/list": PaginatedRequestParams, + "resources/read": ReadResourceRequestParams, + "resources/subscribe": SubscribeRequestParams, + "resources/unsubscribe": UnsubscribeRequestParams, + "logging/setLevel": SetLevelRequestParams, + "completion/complete": CompleteRequestParams, +} +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s +`_request_handlers` stores handler-only; the registry refactor should make this +the registry's responsibility (or store params types alongside handlers).""" + +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { + "notifications/initialized": NotificationParams, + "notifications/roots/list_changed": NotificationParams, + "notifications/progress": ProgressNotificationParams, +} + + +class ServerRegistry(Protocol): + """The handler registry `ServerRunner` consumes. + + The lowlevel `Server` satisfies this via additive methods. + """ + + @property + def name(self) -> str: ... + @property + def version(self) -> str | None: ... + + def get_request_handler(self, method: str) -> Handler | None: ... + def get_notification_handler(self, method: str) -> Handler | None: ... + def get_capabilities( + self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] + ) -> ServerCapabilities: ... + + +def _dump_result(result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, BaseModel): + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(result, dict): + return cast(dict[str, Any], result) + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") + + +@dataclass +class ServerRunner(Generic[LifespanT, ServerTransportT]): + """Per-connection orchestrator. One instance per client connection.""" + + server: ServerRegistry + dispatcher: Dispatcher[ServerTransportT] + lifespan_state: LifespanT + has_standalone_channel: bool + stateless: bool = False + + connection: Connection = field(init=False) + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + raise MCPError( + code=INVALID_REQUEST, + message=f"Received {method!r} before initialization was complete", + ) + handler = self.server.get_request_handler(method) + if handler is None: + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + # TODO: scaffolding — params_type comes from a static lookup until the + # registry stores it alongside the handler. + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + result = await handler(ctx, typed_params) + return _dump_result(result) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + handler = self.server.get_notification_handler(method) + if handler is None: + logger.debug("no handler for notification %s", method) + return + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + await handler(ctx, typed_params) + + def _make_context( + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel + ) -> Context[LifespanT, ServerTransportT]: + # `OnRequest` delivers `DispatchContext[TransportContext]`; this + # ServerRunner instance was constructed for a specific + # `ServerTransportT`, so the narrow is safe by construction. + narrowed = cast(DispatchContext[ServerTransportT], dctx) + meta = getattr(typed_params, "meta", None) + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}) + self.connection.client_info = init.client_info + self.connection.client_capabilities = init.capabilities + # TODO: real version negotiation. This always responds with LATEST, + # which is wrong — the server should pick the highest version both + # sides support and compute a per-connection feature set from it. + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". + self.connection.protocol_version = ( + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION + ) + self._initialized = True + self.connection.initialized.set() + result = InitializeResult( + protocol_version=self.connection.protocol_version, + capabilities=self.server.get_capabilities(NotificationOptions(), {}), + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + ) + return _dump_result(result) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 0000000000..5bff4b2888 --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,154 @@ +"""Tests for `ServerRunner`. + +End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +registry. Covers `_on_request` routing, the initialize handshake, the +init-gate, and that handlers receive a fully-built `Context`. +""" + +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import Server +from mcp.server.runner import ServerRunner +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + ClientCapabilities, + Implementation, + InitializeRequestParams, + Tool, +) + +from ..shared.test_dispatcher import Recorder, echo_handlers + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Context[Any, TransportContext]] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Any, params: Any) -> Any: + # ctx is typed `Any` because Server's on_list_tools kwarg expects the + # legacy ServerRequestContext shape; ServerRunner passes the new + # `Context`. The transition is intentional — Handler is loosely typed. + _seen_ctx.append(ctx) + return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt + assert await client.send_raw_request("ping", None) == {} + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_after_initialize_and_builds_context(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_unknown_method_raises_method_not_found(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True # bypass gate for this test + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=False, + stateless=True, + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + tg.cancel_scope.cancel() From 9f740eb7d35aff5d77af276e653282d989adb7ab Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 20 Apr 2026 19:02:57 +0000 Subject: [PATCH 21/27] feat: ServerRunner middleware (two-tier) + _on_notify ContextMiddleware is a Protocol[L] (contravariant) so Server[L].middleware: list[ContextMiddleware[L]] is properly typed. App-specific middleware sees ctx.lifespan: L; reusable middleware typed ContextMiddleware[object] registers on any Server via contravariance. Context's covariance (previous PR3 commit) makes Context[L, ST] <: Context[L, TransportContext] so the chain composes without casts. dispatch_middleware (DispatchMiddleware list on ServerRunner) wraps the raw _on_request and sees everything including initialize/METHOD_NOT_FOUND. server.middleware (ContextMiddleware) runs inside _on_request after validation/ctx-build and wraps registered handlers only. _on_notify routes notifications/initialized (sets the flag), drops before-init and unknown methods, otherwise builds Context and calls the registered handler. 11 tests over DirectDispatcher + a real lowlevel Server. --- src/mcp/server/context.py | 36 +++++++- src/mcp/server/lowlevel/server.py | 10 +-- src/mcp/server/runner.py | 30 +++++-- tests/server/test_runner.py | 138 ++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4d35f8a902..1c855ae48a 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server._typed_request import TypedServerRequestMixin @@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * if meta: params["_meta"] = meta await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) + + +class ContextMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + + Runs *inside* `ServerRunner._on_request` after params validation and + `Context` construction. Wraps registered handlers (including ``ping``) but + not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific + types) can be typed `ContextMiddleware[object]` — `Context` is covariant in + `LifespanT`, so it registers on any `Server[L]`. + """ + + async def __call__( + self, + ctx: Context[_MwLifespanT, TransportContext], + method: str, + params: BaseModel, + call_next: CallNext, + ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index de12832dc5..466c158bd4 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -58,7 +58,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import ContextMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -199,6 +199,9 @@ def __init__( ] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. + self.middleware: list[ContextMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) # Populate internal handler dicts from on_* kwargs @@ -256,11 +259,6 @@ def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] """Return the handler for a notification method, or ``None``.""" return self._notification_handlers.get(method) - @property - def middleware(self) -> list[Any]: - """Context-tier middleware. Empty until the registry refactor adds registration.""" - return [] - @property def connection_lifespan(self) -> None: """Per-connection lifespan. ``None`` until the registry refactor adds it.""" diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 66093b250a..a7dae289f7 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -17,17 +17,18 @@ from __future__ import annotations import logging -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field +from functools import partial, reduce from typing import Any, Generic, Protocol, cast from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server.connection import Connection -from mcp.server.context import Context +from mcp.server.context import CallNext, Context, ContextMiddleware from mcp.server.lowlevel.server import NotificationOptions -from mcp.shared.dispatcher import DispatchContext, Dispatcher +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import ( @@ -51,7 +52,7 @@ UnsubscribeRequestParams, ) -__all__ = ["ServerRegistry", "ServerRunner"] +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner"] logger = logging.getLogger(__name__) @@ -64,6 +65,7 @@ `Context`-based handlers both fit during the transition. """ + _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) # TODO: remove this lookup once `Server` stores (params_type, handler) in its @@ -105,6 +107,9 @@ def name(self) -> str: ... @property def version(self) -> str | None: ... + @property + def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... + def get_request_handler(self, method: str) -> Handler | None: ... def get_notification_handler(self, method: str) -> Handler | None: ... def get_capabilities( @@ -131,6 +136,7 @@ class ServerRunner(Generic[LifespanT, ServerTransportT]): lifespan_state: LifespanT has_standalone_channel: bool stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) connection: Connection = field(init=False) _initialized: bool = field(init=False) @@ -139,6 +145,16 @@ def __post_init__(self) -> None: self._initialized = self.stateless self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` + and wraps everything — initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. + """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + async def _on_request( self, dctx: DispatchContext[TransportContext], @@ -162,8 +178,10 @@ async def _on_request( # it to INVALID_PARAMS. typed_params = params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - result = await handler(ctx, typed_params) - return _dump_result(result) + call: CallNext = partial(handler, ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) async def _on_notify( self, diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 5bff4b2888..eca10497c5 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -8,6 +8,7 @@ from typing import Any import anyio +import anyio.lowlevel import pytest from mcp.server.connection import Connection @@ -134,6 +135,143 @@ async def test_runner_unknown_method_raises_method_not_found(server: SrvT): tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + + async def on_roots_changed(ctx: Any, params: Any) -> None: + seen.append((ctx, params)) + + server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[trace_mw], + ) + c_req, c_notify = echo_handlers(Recorder()) + on_req = runner._compose_on_request() + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, on_req, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize NOT wrapped; ping and tools/list ARE wrapped. + assert seen_methods == ["ping", "tools/list"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + tg.cancel_scope.cancel() + + @pytest.mark.anyio async def test_runner_stateless_skips_init_gate(server: SrvT): client, server_d = create_direct_dispatcher_pair() From b201e125af9c4a7c27ec45b1c22ba3bef875d154 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 22 Apr 2026 01:28:51 +0000 Subject: [PATCH 22/27] feat: ServerRunner.run() and otel_middleware run() composes dispatch_middleware over _on_request and forwards task_status to dispatcher.run() so callers can 'await tg.start(runner.run)'. otel_middleware is a DispatchMiddleware that wraps each request in a span, mirroring the existing Server._handle_request span shape: name 'MCP handle []', mcp.method.name attribute, W3C trace context extracted from params._meta (SEP-414), and ERROR status if the handler raises. connection_lifespan plumbing (the enter-late dance) is deferred to a separate commit since Server.connection_lifespan is None today. --- src/mcp/server/runner.py | 53 ++++++++++++++++++++++++++- tests/server/test_runner.py | 71 ++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index a7dae289f7..79dfc23e0e 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -22,12 +22,15 @@ from functools import partial, reduce from typing import Any, Generic, Protocol, cast +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server.connection import Connection from mcp.server.context import CallNext, Context, ContextMiddleware from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext @@ -52,7 +55,7 @@ UnsubscribeRequestParams, ) -__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner"] +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] logger = logging.getLogger(__name__) @@ -117,6 +120,44 @@ def get_capabilities( ) -> ServerCapabilities: ... +def otel_middleware(next_on_request: OnRequest) -> OnRequest: + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. + + Mirrors the span shape of the existing `Server._handle_request`: span name + ``"MCP handle []"``, ``mcp.method.name`` attribute, W3C + trace context extracted from ``params._meta`` (SEP-414), and an ERROR + status if the handler raises. + """ + + async def wrapped( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + target: str | None + match params: + case {"name": str() as target}: + pass + case _: + target = None + parent: Any | None + match params: + case {"_meta": {**meta}}: + parent = extract_trace_context(meta) + case _: + parent = None + span_name = f"MCP handle {method}{f' {target}' if target else ''}" + with otel_span(span_name, kind=SpanKind.SERVER, attributes={"mcp.method.name": method}, context=parent) as span: + try: + return await next_on_request(dctx, method, params) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except Exception as e: + span.set_status(StatusCode.ERROR, str(e)) + raise + + return wrapped + + def _dump_result(result: Any) -> dict[str, Any]: if result is None: return {} @@ -145,6 +186,16 @@ def __post_init__(self) -> None: self._initialized = self.stateless self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers + can ``await tg.start(runner.run)`` and resume once the dispatcher is + ready to accept requests. + """ + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index eca10497c5..3d2fd84c0c 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -14,7 +14,7 @@ from mcp.server.connection import Connection from mcp.server.context import Context from mcp.server.lowlevel.server import Server -from mcp.server.runner import ServerRunner +from mcp.server.runner import ServerRunner, otel_middleware from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext @@ -272,6 +272,75 @@ async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_runner_run_drives_dispatcher_end_to_end(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + init = await client.send_raw_request("initialize", _initialize_params()) + tools = await client.send_raw_request("tools/list", None) + assert init["serverInfo"]["name"] == "test-server" + assert tools["tools"][0]["name"] == "t" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_run_applies_dispatch_middleware(server: SrvT): + seen: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[trace_mw], + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("ping", None) + assert seen == ["initialize", "ping"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_otel_middleware_passes_through_result_and_survives_handler_error(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[otel_middleware], + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + tools = await client.send_raw_request("tools/list", None) + assert tools["tools"][0]["name"] == "t" + with pytest.raises(MCPError): + await client.send_raw_request("nonexistent/method", None) + tg.cancel_scope.cancel() + + @pytest.mark.anyio async def test_runner_stateless_skips_init_gate(server: SrvT): client, server_d = create_direct_dispatcher_pair() From a7313fdd873347ad84e6d4e316a09fe578ba0dc7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 25 Apr 2026 21:40:05 +0000 Subject: [PATCH 23/27] =?UTF-8?q?test:=20ServerRunner=20coverage=20to=2010?= =?UTF-8?q?0%=20=E2=80=94=20otel=20span=20assertions=20+=20connected=5Frun?= =?UTF-8?q?ner=20harness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add opentelemetry-sdk as a dev dep and a tests/server/conftest.py 'spans' fixture (TracerProvider + InMemorySpanExporter) so otel_middleware's span contract is observable. - Replace the otel pass-through test with four span-asserting tests (name + target, _meta traceparent → parent, MCPError → ERROR status without traceback, unexpected exception → ERROR status + exception event). These surfaced that start_as_current_span's default set_status_on_exception / record_exception was overwriting the middleware's explicit set_status and attaching tracebacks to protocol-level MCPErrors — now disabled and handled explicitly. - Add handler-return contract tests (None → {}, unsupported → INTERNAL_ERROR). - Introduce connected_runner async-contextmanager test harness and retrofit all tests through runner.run(); drop two tests made redundant by that. Harness closes dispatchers gracefully and re-raises body exceptions outside the task group so failures aren't ExceptionGroup-wrapped (and to avoid a coverage.py trace-loss false-negative on cancel-during-aexit). - Remove the unused Server.connection_lifespan placeholder; it lands with its consumer. --- pyproject.toml | 1 + src/mcp/server/lowlevel/server.py | 5 - src/mcp/server/runner.py | 10 +- src/mcp/shared/_otel.py | 11 +- tests/server/conftest.py | 34 +++ tests/server/test_runner.py | 401 ++++++++++++++---------------- uv.lock | 2 + 7 files changed, 246 insertions(+), 218 deletions(-) create mode 100644 tests/server/conftest.py diff --git a/pyproject.toml b/pyproject.toml index 6d2319621a..5f51fa9b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "pillow>=12.0", "strict-no-cover", "logfire>=3.0.0", + "opentelemetry-sdk>=1.39.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 466c158bd4..a863246a18 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -259,11 +259,6 @@ def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] """Return the handler for a notification method, or ``None``.""" return self._notification_handlers.get(method) - @property - def connection_lifespan(self) -> None: - """Per-connection lifespan. ``None`` until the registry refactor adds it.""" - return None - # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 79dfc23e0e..bb3af04435 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -145,13 +145,21 @@ async def wrapped( case _: parent = None span_name = f"MCP handle {method}{f' {target}' if target else ''}" - with otel_span(span_name, kind=SpanKind.SERVER, attributes={"mcp.method.name": method}, context=parent) as span: + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": method}, + context=parent, + record_exception=False, + set_status_on_exception=False, + ) as span: try: return await next_on_request(dctx, method, params) except MCPError as e: span.set_status(StatusCode.ERROR, e.error.message) raise except Exception as e: + span.record_exception(e) span.set_status(StatusCode.ERROR, str(e)) raise diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 170e873a0f..553b8a0bce 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -20,9 +20,18 @@ def otel_span( kind: SpanKind, attributes: dict[str, Any] | None = None, context: Context | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, ) -> Iterator[Any]: """Create an OTel span.""" - with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + with _tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes, + context=context, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: yield span diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 0000000000..37202f529e --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,34 @@ +"""Shared fixtures for server-side tests.""" + +from collections.abc import Iterator + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +_span_exporter = InMemorySpanExporter() + + +@pytest.fixture(scope="session") +def _tracer_provider() -> TracerProvider: + """Install a real OTel SDK tracer provider once per test session. + + The runtime dependency is ``opentelemetry-api`` only, which yields no-op + ``NonRecordingSpan`` objects. Tests that need to assert on emitted spans + request the `spans` fixture, which depends on this one to make the global + tracer record into an in-memory exporter. + """ + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(_span_exporter)) + trace.set_tracer_provider(provider) + return provider + + +@pytest.fixture +def spans(_tracer_provider: TracerProvider) -> Iterator[InMemorySpanExporter]: + """In-memory OTel span exporter, cleared before and after each test.""" + _span_exporter.clear() + yield _span_exporter + _span_exporter.clear() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 3d2fd84c0c..2006bf6486 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -1,24 +1,31 @@ """Tests for `ServerRunner`. End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the -registry. Covers `_on_request` routing, the initialize handshake, the -init-gate, and that handlers receive a fully-built `Context`. +registry. The `connected_runner` helper starts both sides and (by default) +performs the initialize handshake, so each test exercises only the behaviour +under test. """ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from typing import Any import anyio import anyio.lowlevel import pytest +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import SpanKind, StatusCode from mcp.server.connection import Connection from mcp.server.context import Context from mcp.server.lowlevel.server import Server from mcp.server.runner import ServerRunner, otel_middleware -from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import ( + INTERNAL_ERROR, INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, @@ -58,96 +65,107 @@ async def list_tools(ctx: Any, params: Any) -> Any: return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) -@pytest.mark.anyio -async def test_runner_handles_initialize_and_populates_connection(server: SrvT): +@asynccontextmanager +async def connected_runner( + server: SrvT, + *, + initialized: bool = True, + stateless: bool = False, + has_standalone_channel: bool = True, + dispatch_middleware: list[DispatchMiddleware] | None = None, +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: + """Yield ``(client, runner)`` running over an in-memory dispatcher pair. + + Starts the client (echo handlers) and `runner.run()` in a task group, wraps + the body in ``anyio.fail_after(5)``, and cancels on exit. When + ``initialized`` is true the helper performs the real ``initialize`` request + before yielding, so tests start past the init-gate via the public path. + """ client, server_d = create_direct_dispatcher_pair() runner = ServerRunner( server=server, dispatcher=server_d, lifespan_state=None, - has_standalone_channel=True, + has_standalone_channel=has_standalone_channel, + stateless=stateless, + dispatch_middleware=dispatch_middleware or [], ) c_req, c_notify = echo_handlers(Recorder()) + body_exc: BaseException | None = None async with anyio.create_task_group() as tg: await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - result = await client.send_raw_request("initialize", _initialize_params()) - assert result["serverInfo"]["name"] == "test-server" - assert "tools" in result["capabilities"] - assert runner.connection.client_info is not None - assert runner.connection.client_info.name == "test-client" - assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION - assert runner._initialized is True - tg.cancel_scope.cancel() + await tg.start(runner.run) + try: + with anyio.fail_after(5): + if initialized: + await client.send_raw_request("initialize", _initialize_params()) + yield client, runner + except BaseException as e: + # Capture and re-raise outside the task group so test failures + # surface as the original exception, not an ExceptionGroup wrapper. + body_exc = e + client.close() + server_d.close() + if body_exc is not None: + raise body_exc + + +@pytest.mark.anyio +async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): + """The harness re-raises body exceptions as-is, not as ``ExceptionGroup``.""" + with pytest.raises(RuntimeError, match="boom"): + async with connected_runner(server): + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True @pytest.mark.anyio async def test_runner_gates_requests_before_initialize(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - with pytest.raises(MCPError) as exc: - await client.send_raw_request("tools/list", None) - assert exc.value.error.code == INVALID_REQUEST - # ping is exempt - assert await client.send_raw_request("ping", None) == {} - tg.cancel_scope.cancel() + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt from the gate + assert await client.send_raw_request("ping", None) == {} @pytest.mark.anyio -async def test_runner_routes_to_handler_after_initialize_and_builds_context(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - result = await client.send_raw_request("tools/list", None) - assert result["tools"][0]["name"] == "t" - ctx = _seen_ctx[0] - assert isinstance(ctx, Context) - assert ctx.lifespan is None - assert isinstance(ctx.connection, Connection) - assert ctx.transport.kind == "direct" - tg.cancel_scope.cancel() +async def test_runner_routes_to_handler_and_builds_context(server: SrvT): + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" @pytest.mark.anyio async def test_runner_unknown_method_raises_method_not_found(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - runner._initialized = True # bypass gate for this test - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - with pytest.raises(MCPError) as exc: - await client.send_raw_request("nonexistent/method", None) - assert exc.value.error.code == METHOD_NOT_FOUND - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND @pytest.mark.anyio async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.notify("notifications/initialized", None) - await runner.connection.initialized.wait() - assert runner._initialized is True - tg.cancel_scope.cancel() + async with connected_runner(server, initialized=False) as (client, runner): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True @pytest.mark.anyio @@ -158,36 +176,21 @@ async def on_roots_changed(ctx: Any, params: Any) -> None: seen.append((ctx, params)) server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - runner._initialized = True - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.notify("notifications/roots/list_changed", None) - # DirectDispatcher delivers synchronously; one yield is enough. - await anyio.lowlevel.checkpoint() - assert len(seen) == 1 - assert isinstance(seen[0][0], Context) - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) @pytest.mark.anyio async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.notify("notifications/roots/list_changed", None) # before init: dropped - await client.notify("notifications/initialized", None) - await client.notify("notifications/unknown", None) # no handler: dropped - # No exception raised; both drops are silent. - tg.cancel_scope.cancel() + async with connected_runner(server, initialized=False) as (client, _): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. @pytest.mark.anyio @@ -201,24 +204,9 @@ async def wrapped(dctx: Any, method: str, params: Any) -> Any: return wrapped - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=True, - dispatch_middleware=[trace_mw], - ) - c_req, c_notify = echo_handlers(Recorder()) - on_req = runner._compose_on_request() - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, on_req, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - await client.send_raw_request("tools/list", None) - assert seen_methods == ["initialize", "tools/list"] - tg.cancel_scope.cancel() + async with connected_runner(server, dispatch_middleware=[trace_mw]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] @pytest.mark.anyio @@ -230,19 +218,11 @@ async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: return await call_next() server.middleware.append(ctx_mw) - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - await client.send_raw_request("ping", None) - await client.send_raw_request("tools/list", None) - # initialize NOT wrapped; ping and tools/list ARE wrapped. - assert seen_methods == ["ping", "tools/list"] - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize (sent by the helper) NOT wrapped; ping and tools/list ARE. + assert seen_methods == ["ping", "tools/list"] @pytest.mark.anyio @@ -259,103 +239,102 @@ async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: return mw server.middleware.extend([make_mw("a"), make_mw("b")]) - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - runner._initialized = True - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - await client.send_raw_request("tools/list", None) - assert order == ["a-in", "b-in", "b-out", "a-out"] - tg.cancel_scope.cancel() + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] @pytest.mark.anyio -async def test_runner_run_drives_dispatcher_end_to_end(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) - with anyio.fail_after(5): - init = await client.send_raw_request("initialize", _initialize_params()) - tools = await client.send_raw_request("tools/list", None) - assert init["serverInfo"]["name"] == "test-server" - assert tools["tools"][0]["name"] == "t" - tg.cancel_scope.cancel() +async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): + async def set_level(ctx: Any, params: Any) -> None: + return None + + server._request_handlers["logging/setLevel"] = set_level + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert result == {} @pytest.mark.anyio -async def test_runner_run_applies_dispatch_middleware(server: SrvT): - seen: list[str] = [] +async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_error(server: SrvT): + async def bad_return(ctx: Any, params: Any) -> int: + return 42 - def trace_mw(next_on_request: Any) -> Any: - async def wrapped(dctx: Any, method: str, params: Any) -> Any: - seen.append(method) - return await next_on_request(dctx, method, params) + server._request_handlers["tools/list"] = bad_return + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert "int" in exc.value.error.message - return wrapped - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=True, - dispatch_middleware=[trace_mw], - ) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - await client.send_raw_request("ping", None) - assert seen == ["initialize", "ping"] - tg.cancel_scope.cancel() +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" @pytest.mark.anyio -async def test_otel_middleware_passes_through_result_and_survives_handler_error(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=True, - dispatch_middleware=[otel_middleware], - ) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) - with anyio.fail_after(5): - await client.send_raw_request("initialize", _initialize_params()) - tools = await client.send_raw_request("tools/list", None) - assert tools["tools"][0]["name"] == "t" - with pytest.raises(MCPError): - await client.send_raw_request("nonexistent/method", None) - tg.cancel_scope.cancel() +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: InMemorySpanExporter): + async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: + return {"content": [], "isError": False} + + server._request_handlers["tools/call"] = call_tool + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + [span] = spans.get_finished_spans() + assert span.name == "MCP handle tools/call mytool" + assert span.kind == SpanKind.SERVER + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert span.status.status_code == StatusCode.UNSET @pytest.mark.anyio -async def test_runner_stateless_skips_init_gate(server: SrvT): - client, server_d = create_direct_dispatcher_pair() - runner = ServerRunner( - server=server, - dispatcher=server_d, - lifespan_state=None, - has_standalone_channel=False, - stateless=True, - ) - c_req, c_notify = echo_handlers(Recorder()) - async with anyio.create_task_group() as tg: - await tg.start(client.run, c_req, c_notify) - await tg.start(server_d.run, runner._on_request, runner._on_notify) - with anyio.fail_after(5): - result = await client.send_raw_request("tools/list", None) - assert result["tools"][0]["name"] == "t" - tg.cancel_scope.cancel() +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: InMemorySpanExporter): + parent_span_id = "b7ad6b7169203331" + traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) + [span] = spans.get_finished_spans() + assert span.parent is not None + assert format(span.parent.span_id, "016x") == parent_span_id + assert span.context is not None + assert format(span.context.trace_id, "032x") == "0af7651916cd43dd8448eb211c80319c" + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: InMemorySpanExporter): + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + [span] = spans.get_finished_spans() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found: nonexistent/method" + # MCPError is a protocol-level response, not a crash — no traceback event. + assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: InMemorySpanExporter): + async def failing(ctx: Any, params: Any) -> Any: + raise ValueError("handler blew up") + + server._request_handlers["tools/list"] = failing + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + [span] = spans.get_finished_spans() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" diff --git a/uv.lock b/uv.lock index 5b72e97fce..86ea2f5fb7 100644 --- a/uv.lock +++ b/uv.lock @@ -885,6 +885,7 @@ dev = [ { name = "inline-snapshot" }, { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -937,6 +938,7 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, + { name = "opentelemetry-sdk", specifier = ">=1.39.1" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.4.0" }, From e3076e7e40dbb01334f7f2ccc38a37bf822b93b8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 25 Apr 2026 22:14:12 +0000 Subject: [PATCH 24/27] test: converge span capture on capfire to fix xdist order-dependence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous tests/server/conftest.py called trace.set_tracer_provider() directly, which is set-once per process and raced against logfire's capfire fixture (tests/shared/test_otel.py) under xdist — whichever ran first in a worker won, the other's tests broke. Converge on capfire as the single span-capture owner since logfire.configure() already handles repeat calls by swapping span processors instead of re-setting the provider: - tests/conftest.py: set LOGFIRE_DISTRIBUTED_TRACING=true so propagation tests don't trip logfire's 'found propagated trace context' RuntimeWarning. - tests/server/conftest.py: SpanCapture adapter over capfire.exporter — filters to the mcp-python-sdk instrumentation scope and excludes logfire's pending_span markers, so tests assert on raw ReadableSpan without importing logfire types. - tests/shared/test_otel.py: drop the now-unneeded filterwarnings decorator. --- tests/conftest.py | 11 ++++++++ tests/server/conftest.py | 55 ++++++++++++++++++++++--------------- tests/server/test_runner.py | 18 ++++++------ tests/shared/test_otel.py | 3 -- 4 files changed, 53 insertions(+), 34 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index af7e479932..b83c472135 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,16 @@ +import os + import pytest +# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite +# uses a single span-capture mechanism: logfire's `capfire` fixture (its +# `configure()` swaps span processors on repeat calls rather than re-setting +# the provider). Logfire's default `distributed_tracing=None` emits a +# RuntimeWarning + diagnostic span when incoming W3C trace context is +# extracted; several tests exercise that propagation deliberately, so opt in +# suite-wide. Set before logfire is imported anywhere. +os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") + @pytest.fixture def anyio_backend(): diff --git a/tests/server/conftest.py b/tests/server/conftest.py index 37202f529e..290ccc957a 100644 --- a/tests/server/conftest.py +++ b/tests/server/conftest.py @@ -3,32 +3,43 @@ from collections.abc import Iterator import pytest -from opentelemetry import trace -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from logfire.testing import CaptureLogfire, TestExporter +from opentelemetry.sdk.trace import ReadableSpan -_span_exporter = InMemorySpanExporter() +class SpanCapture: + """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. -@pytest.fixture(scope="session") -def _tracer_provider() -> TracerProvider: - """Install a real OTel SDK tracer provider once per test session. - - The runtime dependency is ``opentelemetry-api`` only, which yields no-op - ``NonRecordingSpan`` objects. Tests that need to assert on emitted spans - request the `spans` fixture, which depends on this one to make the global - tracer record into an in-memory exporter. + `finished()` returns the raw `ReadableSpan` objects emitted by the + ``mcp-python-sdk`` instrumentation scope, filtered to exclude logfire's + synthetic ``pending_span`` markers, so tests can assert directly on + `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. """ - provider = TracerProvider() - provider.add_span_processor(SimpleSpanProcessor(_span_exporter)) - trace.set_tracer_provider(provider) - return provider + + def __init__(self, exporter: TestExporter) -> None: + self._exporter = exporter + + def clear(self) -> None: + self._exporter.clear() + + def finished(self) -> list[ReadableSpan]: + return [ + s + for s in self._exporter.exported_spans + if s.instrumentation_scope is not None + and s.instrumentation_scope.name == "mcp-python-sdk" + and not (s.attributes and s.attributes.get("logfire.span_type") == "pending_span") + ] @pytest.fixture -def spans(_tracer_provider: TracerProvider) -> Iterator[InMemorySpanExporter]: - """In-memory OTel span exporter, cleared before and after each test.""" - _span_exporter.clear() - yield _span_exporter - _span_exporter.clear() +def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: + """In-memory MCP span capture, cleared before and after each test. + + Backed by the project-level `capfire` override (see ``tests/conftest.py``) + so there is a single global tracer provider for the suite. + """ + capture = SpanCapture(capfire.exporter) + capture.clear() + yield capture + capture.clear() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 2006bf6486..843b0ae8b9 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -13,7 +13,6 @@ import anyio import anyio.lowlevel import pytest -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import SpanKind, StatusCode from mcp.server.connection import Connection @@ -36,6 +35,7 @@ ) from ..shared.test_dispatcher import Recorder, echo_handlers +from .conftest import SpanCapture def _initialize_params() -> dict[str, Any]: @@ -276,7 +276,7 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): @pytest.mark.anyio -async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: return {"content": [], "isError": False} @@ -285,7 +285,7 @@ async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: spans.clear() result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) assert result == {"content": [], "isError": False} - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.name == "MCP handle tools/call mytool" assert span.kind == SpanKind.SERVER assert span.attributes is not None @@ -294,13 +294,13 @@ async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: @pytest.mark.anyio -async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: SpanCapture): parent_span_id = "b7ad6b7169203331" traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.parent is not None assert format(span.parent.span_id, "016x") == parent_span_id assert span.context is not None @@ -308,13 +308,13 @@ async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, s @pytest.mark.anyio -async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() with pytest.raises(MCPError) as exc: await client.send_raw_request("nonexistent/method", None) assert exc.value.error.code == METHOD_NOT_FOUND - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR assert span.status.description == "Method not found: nonexistent/method" # MCPError is a protocol-level response, not a crash — no traceback event. @@ -322,7 +322,7 @@ async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, s @pytest.mark.anyio -async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: InMemorySpanExporter): +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): async def failing(ctx: Any, params: Any) -> Any: raise ValueError("handler blew up") @@ -332,7 +332,7 @@ async def failing(ctx: Any, params: Any) -> Any: with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) assert exc.value.error.code == INTERNAL_ERROR - [span] = spans.get_finished_spans() + [span] = spans.finished() assert span.status.status_code == StatusCode.ERROR assert span.status.description == "handler blew up" [event] = [e for e in span.events if e.name == "exception"] diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index ec7ff78cc1..a7df4c4294 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -10,9 +10,6 @@ pytestmark = pytest.mark.anyio -# Logfire warns about propagated trace context by default (distributed_tracing=None). -# This is expected here since we're testing cross-boundary context propagation. -@pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_client_and_server_spans(capfire: CaptureLogfire): """Verify that calling a tool produces client and server spans with correct attributes.""" server = MCPServer("test") From 6c7953e3792eabaf25338e1a730b523392ba46e6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:29:55 +0000 Subject: [PATCH 25/27] feat: Server registry stores HandlerEntry; ServerRunner consumes Server[L] directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server is generic in LifespanResultT only — no TransportContextT. Spike (scratch/spike-tt-on-server) found a third generic breaks bare-Server plumbing helpers via invariance and only buys one None-check; it remains additive later via PEP 696 default if demand materialises. TT stays on the transport layer (Dispatcher/DispatchContext/BaseContext in mcp.shared); the server layer (Server/Context/ServerRunner/ServerMiddleware) consumes base TransportContext. - HandlerEntry[L] frozen dataclass (params_type, handler) replaces bare callables in the registry; params type erased to Any in storage, correlated at add_request_handler[P] - Public add_request_handler/add_notification_handler; capabilities() zero-arg (notification_options/experimental_capabilities now ctor kwargs) - ServerRunner drops the ServerRegistry Protocol scaffold and reads Server[L] directly; _make_context no longer narrows dctx - ServerMiddleware[L] (one contravariant param) - Context[L] (BaseContext[TransportContext] fixed) --- src/mcp/server/context.py | 17 ++-- src/mcp/server/lowlevel/server.py | 147 +++++++++++++++++++--------- src/mcp/server/runner.py | 122 +++++------------------ tests/server/test_runner.py | 76 ++++++++++---- tests/server/test_server_context.py | 12 +-- 5 files changed, 197 insertions(+), 177 deletions(-) diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 1c855ae48a..1cf2be1899 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -33,10 +33,9 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex LifespanT = TypeVar("LifespanT", default=Any, covariant=True) -TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) -class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]): +class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]): """Server-side per-request context. Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), @@ -50,7 +49,7 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener def __init__( self, - dctx: DispatchContext[TransportT], + dctx: DispatchContext[TransportContext], *, lifespan: LifespanT, connection: Connection, @@ -94,7 +93,7 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * _MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) -class ContextMiddleware(Protocol[_MwLifespanT]): +class ServerMiddleware(Protocol[_MwLifespanT]): """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. Runs *inside* `ServerRunner._on_request` after params validation and @@ -102,15 +101,15 @@ class ContextMiddleware(Protocol[_MwLifespanT]): not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed outermost-first on `Server.middleware`. - `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific - middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific - types) can be typed `ContextMiddleware[object]` — `Context` is covariant in - `LifespanT`, so it registers on any `Server[L]`. + `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware can be typed + `ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it + registers on any `Server[L]`. """ async def __call__( self, - ctx: Context[_MwLifespanT, TransportContext], + ctx: Context[_MwLifespanT], method: str, params: BaseModel, call_next: CallNext, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a863246a18..375ca94c0d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -41,11 +41,13 @@ async def main(): import warnings from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from dataclasses import dataclass from importlib.metadata import version as importlib_version from typing import Any, Generic, cast import anyio from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -58,7 +60,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ContextMiddleware, ServerRequestContext +from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -76,6 +78,30 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) +_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel) + +RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]] +"""A registered request handler: ``(ctx, params) -> result``.""" + +NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]] +"""A registered notification handler: ``(ctx, params) -> None``.""" + + +@dataclass(frozen=True, slots=True) +class HandlerEntry(Generic[LifespanResultT]): + """A registered handler and the params model to validate incoming params against. + + Stored in `Server._request_handlers` / `_notification_handlers` and consumed + by `ServerRunner` to validate, build `Context`, and invoke. The handler's + second-argument type is erased to ``Any`` in storage (each entry has a + different concrete params type and `Callable` parameters are contravariant); + the precise type is recoverable via `params_type`. The correlation is + enforced at registration time by `Server.add_request_handler`. + """ + + params_type: type[BaseModel] + handler: RequestHandler[LifespanResultT, Any] + class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -85,7 +111,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Returns: @@ -109,6 +135,8 @@ def __init__( instructions: str | None = None, website_url: str | None = None, icons: list[types.Icon] | None = None, + notification_options: NotificationOptions | None = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, lifespan: Callable[ [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], @@ -193,57 +221,77 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} - self._notification_handlers: dict[ - str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] - ] = {} + self._notification_options = notification_options or NotificationOptions() + self._experimental_capabilities = experimental_capabilities or {} + self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} + self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None # Context-tier middleware consumed by `ServerRunner`. Additive; the # existing `run()` path ignores it. - self.middleware: list[ContextMiddleware[LifespanResultT]] = [] + self.middleware: list[ServerMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) - # Populate internal handler dicts from on_* kwargs - self._request_handlers.update( - { - method: handler - for method, handler in { - "ping": on_ping, - "prompts/list": on_list_prompts, - "prompts/get": on_get_prompt, - "resources/list": on_list_resources, - "resources/templates/list": on_list_resource_templates, - "resources/read": on_read_resource, - "resources/subscribe": on_subscribe_resource, - "resources/unsubscribe": on_unsubscribe_resource, - "tools/list": on_list_tools, - "tools/call": on_call_tool, - "logging/setLevel": on_set_logging_level, - "completion/complete": on_completion, - }.items() - if handler is not None - } - ) + _spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [ + ("ping", types.RequestParams, on_ping), + ("prompts/list", types.PaginatedRequestParams, on_list_prompts), + ("prompts/get", types.GetPromptRequestParams, on_get_prompt), + ("resources/list", types.PaginatedRequestParams, on_list_resources), + ("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates), + ("resources/read", types.ReadResourceRequestParams, on_read_resource), + ("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource), + ("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource), + ("tools/list", types.PaginatedRequestParams, on_list_tools), + ("tools/call", types.CallToolRequestParams, on_call_tool), + ("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level), + ("completion/complete", types.CompleteRequestParams, on_completion), + ] + self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None}) + _spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [ + ("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed), + ("notifications/progress", types.ProgressNotificationParams, on_progress), + ] self._notification_handlers.update( - { - method: handler - for method, handler in { - "notifications/roots/list_changed": on_roots_list_changed, - "notifications/progress": on_progress, - }.items() - if handler is not None - } + {m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None} ) + def add_request_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: RequestHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a request handler for ``method``. + + ``params_type`` is the model incoming params are validated against + before the handler is invoked. It should subclass `RequestParams` so + ``_meta`` parses uniformly. Replaces any existing handler for the same + method (no collision guard against spec methods). + """ + self._request_handlers[method] = HandlerEntry(params_type, handler) + + def add_notification_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: NotificationHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a notification handler for ``method``. + + ``params_type`` should subclass `NotificationParams` so ``_meta`` + parses uniformly. Replaces any existing handler. + """ + self._notification_handlers[method] = HandlerEntry(params_type, handler) + def _add_request_handler( self, method: str, - handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + handler: RequestHandler[LifespanResultT, Any], ) -> None: - """Add a request handler, silently replacing any existing handler for the same method.""" - self._request_handlers[method] = handler + # TODO: remove once experimental tasks plumbing and remaining callers + # migrate to `add_request_handler` with an explicit params_type. + self.add_request_handler(method, types.RequestParams, handler) def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" @@ -251,14 +299,18 @@ def _has_handler(self, method: str) -> bool: # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ - def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: - """Return the handler for a request method, or ``None``.""" + def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a request method, or ``None``.""" return self._request_handlers.get(method) - def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: - """Return the handler for a notification method, or ``None``.""" + def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a notification method, or ``None``.""" return self._notification_handlers.get(method) + def capabilities(self) -> types.ServerCapabilities: + """Derive `ServerCapabilities` from registered handlers and constructor options.""" + return self.get_capabilities(self._notification_options, self._experimental_capabilities) + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -474,7 +526,8 @@ async def _handle_request( attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, context=parent_context, ) as span: - if handler := self._request_handlers.get(req.method): + if entry := self._request_handlers.get(req.method): + handler = entry.handler logger.debug("Dispatching request of type %s", type(req).__name__) try: @@ -533,7 +586,8 @@ async def _handle_request( span.set_status(StatusCode.ERROR, response.message) try: - await message.respond(response) + # TODO: cast goes away when `_handle_request` is deleted. + await message.respond(cast(types.ServerResult | types.ErrorData, response)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): # Transport closed between handler unblocking and respond. Happens # when _receive_loop's finally wakes a handler blocked on @@ -552,7 +606,8 @@ async def _handle_notification( session: ServerSession, lifespan_context: LifespanResultT, ) -> None: - if handler := self._notification_handlers.get(notify.method): + if entry := self._notification_handlers.get(notify.method): + handler = entry.handler logger.debug("Dispatching notification of type %s", type(notify).__name__) try: diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index bb3af04435..1ef711d020 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -10,17 +10,16 @@ `Context`, runs the middleware chain, returns the result dict * drives ``dispatcher.run()`` and the per-connection lifespan -`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies -it via additive methods so the existing ``Server.run()`` path is unaffected. +`ServerRunner` holds a `Server` directly — `Server` is the registry. """ from __future__ import annotations import logging -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial, reduce -from typing import Any, Generic, Protocol, cast +from typing import Any, Generic, cast import anyio.abc from opentelemetry.trace import SpanKind, StatusCode @@ -28,8 +27,8 @@ from typing_extensions import TypeVar from mcp.server.connection import Connection -from mcp.server.context import CallNext, Context, ContextMiddleware -from mcp.server.lowlevel.server import NotificationOptions +from mcp.server.context import CallNext, Context, ServerMiddleware +from mcp.server.lowlevel.server import Server from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError @@ -38,87 +37,20 @@ INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, - CallToolRequestParams, - CompleteRequestParams, - GetPromptRequestParams, Implementation, InitializeRequestParams, InitializeResult, - NotificationParams, - PaginatedRequestParams, - ProgressNotificationParams, - ReadResourceRequestParams, - RequestParams, - ServerCapabilities, - SetLevelRequestParams, - SubscribeRequestParams, - UnsubscribeRequestParams, ) -__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] +__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] logger = logging.getLogger(__name__) LifespanT = TypeVar("LifespanT", default=Any) -ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) - -Handler = Callable[..., Awaitable[Any]] -"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely -so the existing `ServerRequestContext`-based handlers and the new -`Context`-based handlers both fit during the transition. -""" _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) -# TODO: remove this lookup once `Server` stores (params_type, handler) in its -# registry directly. This is scaffolding so ServerRunner can validate params -# without changing the existing `_request_handlers` dict shape. -_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { - "ping": RequestParams, - "tools/list": PaginatedRequestParams, - "tools/call": CallToolRequestParams, - "prompts/list": PaginatedRequestParams, - "prompts/get": GetPromptRequestParams, - "resources/list": PaginatedRequestParams, - "resources/templates/list": PaginatedRequestParams, - "resources/read": ReadResourceRequestParams, - "resources/subscribe": SubscribeRequestParams, - "resources/unsubscribe": UnsubscribeRequestParams, - "logging/setLevel": SetLevelRequestParams, - "completion/complete": CompleteRequestParams, -} -"""Spec method → params model. Scaffolding while the lowlevel `Server`'s -`_request_handlers` stores handler-only; the registry refactor should make this -the registry's responsibility (or store params types alongside handlers).""" - -_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { - "notifications/initialized": NotificationParams, - "notifications/roots/list_changed": NotificationParams, - "notifications/progress": ProgressNotificationParams, -} - - -class ServerRegistry(Protocol): - """The handler registry `ServerRunner` consumes. - - The lowlevel `Server` satisfies this via additive methods. - """ - - @property - def name(self) -> str: ... - @property - def version(self) -> str | None: ... - - @property - def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... - - def get_request_handler(self, method: str) -> Handler | None: ... - def get_notification_handler(self, method: str) -> Handler | None: ... - def get_capabilities( - self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] - ) -> ServerCapabilities: ... - def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. @@ -177,11 +109,11 @@ def _dump_result(result: Any) -> dict[str, Any]: @dataclass -class ServerRunner(Generic[LifespanT, ServerTransportT]): +class ServerRunner(Generic[LifespanT]): """Per-connection orchestrator. One instance per client connection.""" - server: ServerRegistry - dispatcher: Dispatcher[ServerTransportT] + server: Server[LifespanT] + dispatcher: Dispatcher[TransportContext] lifespan_state: LifespanT has_standalone_channel: bool stateless: bool = False @@ -227,17 +159,15 @@ async def _on_request( code=INVALID_REQUEST, message=f"Received {method!r} before initialization was complete", ) - handler = self.server.get_request_handler(method) - if handler is None: + entry = self.server.get_request_handler(method) + if entry is None: raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") - # TODO: scaffolding — params_type comes from a static lookup until the - # registry stores it alongside the handler. - params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) # ValidationError propagates; the dispatcher's exception boundary maps # it to INVALID_PARAMS. - typed_params = params_type.model_validate(params or {}) + typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - call: CallNext = partial(handler, ctx, typed_params) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + call: CallNext = partial(cast(Any, entry.handler), ctx, typed_params) for mw in reversed(self.server.middleware): call = partial(mw, ctx, method, typed_params, call) return _dump_result(await call()) @@ -255,24 +185,18 @@ async def _on_notify( if not self._initialized: logger.debug("dropped %s: received before initialization", method) return - handler = self.server.get_notification_handler(method) - if handler is None: + entry = self.server.get_notification_handler(method) + if entry is None: logger.debug("no handler for notification %s", method) return - params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) - typed_params = params_type.model_validate(params or {}) + typed_params = entry.params_type.model_validate(params or {}) ctx = self._make_context(dctx, typed_params) - await handler(ctx, typed_params) - - def _make_context( - self, dctx: DispatchContext[TransportContext], typed_params: BaseModel - ) -> Context[LifespanT, ServerTransportT]: - # `OnRequest` delivers `DispatchContext[TransportContext]`; this - # ServerRunner instance was constructed for a specific - # `ServerTransportT`, so the narrow is safe by construction. - narrowed = cast(DispatchContext[ServerTransportT], dctx) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + await cast(Any, entry.handler)(ctx, typed_params) + + def _make_context(self, dctx: DispatchContext[TransportContext], typed_params: BaseModel) -> Context[LifespanT]: meta = getattr(typed_params, "meta", None) - return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + return Context(dctx, lifespan=self.lifespan_state, connection=self.connection, meta=meta) def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: init = InitializeRequestParams.model_validate(params or {}) @@ -289,7 +213,7 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any] self.connection.initialized.set() result = InitializeResult( protocol_version=self.connection.protocol_version, - capabilities=self.server.get_capabilities(NotificationOptions(), {}), + capabilities=self.server.capabilities(), server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), ) return _dump_result(result) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 843b0ae8b9..5ece8b9cfb 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -8,7 +8,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any +from typing import Any, cast import anyio import anyio.lowlevel @@ -17,20 +17,25 @@ from mcp.server.connection import Connection from mcp.server.context import Context -from mcp.server.lowlevel.server import Server +from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.runner import ServerRunner, otel_middleware from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchMiddleware from mcp.shared.exceptions import MCPError -from mcp.shared.transport_context import TransportContext from mcp.types import ( INTERNAL_ERROR, INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + CallToolRequestParams, ClientCapabilities, Implementation, InitializeRequestParams, + ListToolsResult, + NotificationParams, + PaginatedRequestParams, + RequestParams, + SetLevelRequestParams, Tool, ) @@ -46,7 +51,7 @@ def _initialize_params() -> dict[str, Any]: ).model_dump(by_alias=True, exclude_none=True) -_seen_ctx: list[Context[Any, TransportContext]] = [] +_seen_ctx: list[Context[Any]] = [] SrvT = Server[dict[str, Any]] @@ -55,12 +60,11 @@ def server() -> SrvT: """A lowlevel Server with one tools/list handler registered.""" _seen_ctx.clear() - async def list_tools(ctx: Any, params: Any) -> Any: - # ctx is typed `Any` because Server's on_list_tools kwarg expects the - # legacy ServerRequestContext shape; ServerRunner passes the new - # `Context`. The transition is intentional — Handler is loosely typed. + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + # ctx is `Any` while `on_*` kwargs are typed against `ServerRequestContext` + # but `ServerRunner` passes the new `Context`; tightens once the alias lands. _seen_ctx.append(ctx) - return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) @@ -73,7 +77,7 @@ async def connected_runner( stateless: bool = False, has_standalone_channel: bool = True, dispatch_middleware: list[DispatchMiddleware] | None = None, -) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[None, TransportContext]]]: +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: """Yield ``(client, runner)`` running over an in-memory dispatcher pair. Starts the client (echo handlers) and `runner.run()` in a task group, wraps @@ -85,7 +89,7 @@ async def connected_runner( runner = ServerRunner( server=server, dispatcher=server_d, - lifespan_state=None, + lifespan_state={}, has_standalone_channel=has_standalone_channel, stateless=stateless, dispatch_middleware=dispatch_middleware or [], @@ -147,7 +151,7 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert result["tools"][0]["name"] == "t" ctx = _seen_ctx[0] assert isinstance(ctx, Context) - assert ctx.lifespan is None + assert ctx.lifespan == {} assert isinstance(ctx.connection, Connection) assert ctx.transport.kind == "direct" @@ -175,7 +179,7 @@ async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): async def on_roots_changed(ctx: Any, params: Any) -> None: seen.append((ctx, params)) - server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) async with connected_runner(server) as (client, _): await client.notify("notifications/roots/list_changed", None) # DirectDispatcher delivers synchronously; one yield is enough. @@ -249,7 +253,7 @@ async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): async def set_level(ctx: Any, params: Any) -> None: return None - server._request_handlers["logging/setLevel"] = set_level + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) async with connected_runner(server) as (client, _): result = await client.send_raw_request("logging/setLevel", {"level": "info"}) assert result == {} @@ -260,7 +264,9 @@ async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_er async def bad_return(ctx: Any, params: Any) -> int: return 42 - server._request_handlers["tools/list"] = bad_return + # cast: deliberately registering a handler with a bad return type to + # exercise the runtime check; pyright would (correctly) reject it otherwise. + server.add_request_handler("tools/list", PaginatedRequestParams, cast(Any, bad_return)) async with connected_runner(server) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) @@ -275,12 +281,48 @@ async def test_runner_stateless_skips_init_gate(server: SrvT): assert result["tools"][0]["name"] == "t" +@pytest.mark.anyio +async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): + class GreetParams(RequestParams): + name: str + + received: list[GreetParams] = [] + + async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: + received.append(params) + return {"greeting": f"hello {params.name}"} + + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("custom/greet", {"name": "world"}) + assert result == {"greeting": "hello world"} + assert isinstance(received[0], GreetParams) + assert received[0].name == "world" + + +@pytest.mark.anyio +async def test_server_capabilities_reflects_ctor_options_in_initialize_result(): + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + server: SrvT = Server( + name="caps-test", + on_list_tools=list_tools, + notification_options=NotificationOptions(tools_changed=True), + experimental_capabilities={"ext": {"k": "v"}}, + ) + async with connected_runner(server, initialized=False) as (client, _): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["capabilities"]["tools"]["listChanged"] is True + assert result["capabilities"]["experimental"] == {"ext": {"k": "v"}} + + @pytest.mark.anyio async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: return {"content": [], "isError": False} - server._request_handlers["tools/call"] = call_tool + server.add_request_handler("tools/call", CallToolRequestParams, call_tool) async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) @@ -326,7 +368,7 @@ async def test_otel_middleware_records_error_status_on_handler_exception(server: async def failing(ctx: Any, params: Any) -> Any: raise ValueError("handler blew up") - server._request_handlers["tools/list"] = failing + server.add_request_handler("tools/list", PaginatedRequestParams, failing) async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): spans.clear() with pytest.raises(MCPError) as exc: diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index e01de34d33..43c2069a87 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -31,11 +31,11 @@ class _Lifespan: @pytest.mark.anyio async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): - captured: list[Context[_Lifespan, TransportContext]] = [] + captured: list[Context[_Lifespan]] = [] conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) captured.append(ctx) return {} @@ -62,7 +62,7 @@ async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | results: list[CreateMessageResult] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) results.append( @@ -92,7 +92,7 @@ async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | results: list[ListRootsResult] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) results.append(await ctx.send_request(ListRootsRequest())) @@ -113,7 +113,7 @@ async def test_context_log_sends_request_scoped_message_notification(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) await ctx.log("debug", "hello") @@ -137,7 +137,7 @@ async def test_context_log_includes_logger_and_meta_when_supplied(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan, TransportContext] = Context( + ctx: Context[_Lifespan] = Context( dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) ) await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) From 0f265115fb4e35b4355e4431ca51cfe9b46fc672 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 7 May 2026 19:19:51 +0000 Subject: [PATCH 26/27] feat: Connection.state + exit_stack; ctx.session_id/headers; TransportContext.headers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per-connection state without a connection_lifespan CM or a second Server generic. Stateless is the default deployment, where a per-connection lifespan would wrap a single request; the enter-late mechanics it would need (race init vs dispatcher-done, ready-gate) were more machinery than the use case warrants. - Connection.session_id: str | None — set by the mount via ServerRunner(session_id=...); per-connection, not per-message - Connection.state: dict[str, Any] — scratch that persists across requests; handlers/middleware read and write freely - Connection.exit_stack: AsyncExitStack — handlers/middleware push CMs or callbacks for per-connection teardown; ServerRunner.run() unwinds it (shielded) in a finally after dispatcher.run() returns - TransportContext.headers: Mapping[str, str] | None on the base — populated by HTTP transports, None on stdio - Context.session_id / Context.headers convenience properties - create_direct_dispatcher_pair(headers=...) and connected_runner(session_id=..., headers=...) for tests --- src/mcp/server/connection.py | 26 +++++-- src/mcp/server/context.py | 19 ++++- src/mcp/server/runner.py | 15 +++- src/mcp/shared/direct_dispatcher.py | 4 +- src/mcp/shared/transport_context.py | 8 +++ tests/server/test_runner.py | 106 +++++++++++++++++++++++++++- 6 files changed, 163 insertions(+), 15 deletions(-) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index df3652ce0e..5991715a44 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -1,9 +1,10 @@ """`Connection` — per-client connection state and the standalone outbound channel. Always present on `Context` (never ``None``), even in stateless deployments. -Holds peer info populated at ``initialize`` time, the per-connection lifespan -output, and an `Outbound` for the standalone stream (the SSE GET stream in -streamable HTTP, or the single duplex stream in stdio). +Holds peer info populated at ``initialize`` time, per-connection scratch +``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the +standalone stream (the SSE GET stream in streamable HTTP, or the single duplex +stream in stdio). `notify` is best-effort: it never raises. If there's no standalone channel (stateless HTTP) or the stream has been dropped, the notification is @@ -14,6 +15,7 @@ import logging from collections.abc import Mapping +from contextlib import AsyncExitStack from typing import Any import anyio @@ -44,17 +46,27 @@ class Connection(TypedServerRequestMixin): ``None`` until ``initialize`` completes; ``initialized`` is set then. """ - def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None: + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: self._outbound = outbound self.has_standalone_channel = has_standalone_channel + self.session_id: str | None = session_id self.client_info: Implementation | None = None self.client_capabilities: ClientCapabilities | None = None self.protocol_version: str | None = None self.initialized: anyio.Event = anyio.Event() - # TODO: make this generic (Connection[StateT]) once connection_lifespan - # wiring lands in ServerRunner. - self.state: Any = None + + self.state: dict[str, Any] = {} + """Per-connection scratch state. Handlers and middleware may read and + write freely; persists across requests on this connection.""" + + self.exit_stack: AsyncExitStack = AsyncExitStack() + """Cleanup stack unwound by `ServerRunner` when the connection closes. + + Push context managers (``await exit_stack.enter_async_context(...)``) + or callbacks (``exit_stack.push_async_callback(...)``) from handlers or + middleware to register per-connection teardown. Unwound LIFO after + `dispatcher.run()` returns, shielded from cancellation.""" async def send_raw_request( self, diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 1cf2be1899..d1514a9add 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass from typing import Any, Generic, Protocol @@ -69,6 +69,23 @@ def connection(self) -> Connection: """The per-client `Connection` for this request's connection.""" return self._connection + @property + def session_id(self) -> str | None: + """The transport's session id for this connection, when one exists. + + Convenience for ``ctx.connection.session_id``. ``None`` on stdio and + stateless HTTP. + """ + return self._connection.session_id + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Convenience for ``ctx.transport.headers``. ``None`` on stdio. + """ + return self.transport.headers + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: """Send a request-scoped ``notifications/message`` log entry. diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 1ef711d020..1ba732ec4b 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -116,6 +116,7 @@ class ServerRunner(Generic[LifespanT]): dispatcher: Dispatcher[TransportContext] lifespan_state: LifespanT has_standalone_channel: bool + session_id: str | None = None stateless: bool = False dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) @@ -124,7 +125,9 @@ class ServerRunner(Generic[LifespanT]): def __post_init__(self) -> None: self._initialized = self.stateless - self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + self.connection = Connection( + self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id + ) async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: """Drive the dispatcher until the underlying channel closes. @@ -132,9 +135,15 @@ async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STAT Composes `dispatch_middleware` over `_on_request` and hands the result to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers can ``await tg.start(runner.run)`` and resume once the dispatcher is - ready to accept requests. + ready to accept requests. Once the dispatcher exits, + `connection.exit_stack` is unwound (shielded) so any per-connection + cleanup registered by handlers or middleware runs to completion. """ - await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + try: + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + finally: + with anyio.CancelScope(shield=True): + await self.connection.exit_stack.aclose() def _compose_on_request(self) -> OnRequest: """Wrap `_on_request` in `dispatch_middleware`, outermost-first. diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 27443ec874..1842cf8abc 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -162,18 +162,20 @@ async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) def create_direct_dispatcher_pair( *, can_send_request: bool = True, + headers: Mapping[str, str] | None = None, ) -> tuple[DirectDispatcher, DirectDispatcher]: """Create two `DirectDispatcher` instances wired to each other. Args: can_send_request: Sets `TransportContext.can_send_request` on both sides. Pass ``False`` to simulate a transport with no back-channel. + headers: Sets `TransportContext.headers` on both sides. Returns: A ``(left, right)`` pair. Conventionally ``left`` is the client side and ``right`` is the server side, but the wiring is symmetric. """ - ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) left = DirectDispatcher(ctx) right = DirectDispatcher(ctx) left.connect_to(right) diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 832cead515..9346116707 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -6,6 +6,7 @@ dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. """ +from collections.abc import Mapping from dataclasses import dataclass __all__ = ["TransportContext"] @@ -28,3 +29,10 @@ class TransportContext: stdio, SSE, and stateful streamable HTTP. When ``False``, `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ + + headers: Mapping[str, str] | None = None + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; ``None`` on stdio. Handlers should + None-check before use. + """ diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 5ece8b9cfb..33df234dbe 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -6,8 +6,8 @@ under test. """ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Mapping +from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, cast import anyio @@ -76,6 +76,8 @@ async def connected_runner( initialized: bool = True, stateless: bool = False, has_standalone_channel: bool = True, + session_id: str | None = None, + headers: Mapping[str, str] | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, ) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: """Yield ``(client, runner)`` running over an in-memory dispatcher pair. @@ -85,12 +87,13 @@ async def connected_runner( ``initialized`` is true the helper performs the real ``initialize`` request before yielding, so tests start past the init-gate via the public path. """ - client, server_d = create_direct_dispatcher_pair() + client, server_d = create_direct_dispatcher_pair(headers=headers) runner = ServerRunner( server=server, dispatcher=server_d, lifespan_state={}, has_standalone_channel=has_standalone_channel, + session_id=session_id, stateless=stateless, dispatch_middleware=dispatch_middleware or [], ) @@ -380,3 +383,100 @@ async def failing(ctx: Any, params: Any) -> Any: [event] = [e for e in span.events if e.name == "exception"] assert event.attributes is not None assert event.attributes["exception.type"] == "ValueError" + + +@pytest.mark.anyio +async def test_connection_state_persists_across_requests_on_same_connection(server: SrvT) -> None: + async def count(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + ctx.connection.state["n"] = ctx.connection.state.get("n", 0) + 1 + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, count) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + await client.send_raw_request("tools/list", None) + assert runner.connection.state == {"n": 2} + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_pushed_callback_after_close(server: SrvT) -> None: + cleaned: list[str] = [] + + async def push(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + async def _cleanup() -> None: + cleaned.append("done") + + ctx.connection.exit_stack.push_async_callback(_cleanup) + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, push) + async with connected_runner(server) as (client, _runner): + await client.send_raw_request("tools/list", None) + assert cleaned == [] + assert cleaned == ["done"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_unwinds_entered_context_manager_after_close(server: SrvT) -> None: + events: list[str] = [] + + class _Tracker(AbstractAsyncContextManager[str]): + async def __aenter__(self) -> str: + events.append("enter") + return "resource" + + async def __aexit__(self, *exc: object) -> None: + events.append("exit") + + async def acquire(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + res = await ctx.connection.exit_stack.enter_async_context(_Tracker()) + ctx.connection.state["res"] = res + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, acquire) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + assert events == ["enter"] + assert runner.connection.state["res"] == "resource" + assert events == ["enter", "exit"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error(server: SrvT) -> None: + cleaned: list[int] = [] + + async def push_then_fail(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + for i in (1, 2, 3): + ctx.connection.exit_stack.push_async_callback(_append, i) + raise RuntimeError("boom") + + async def _append(i: int) -> None: + cleaned.append(i) + + server.add_request_handler("tools/list", PaginatedRequestParams, push_then_fail) + async with connected_runner(server) as (client, _runner): + with pytest.raises(MCPError) as ei: + await client.send_raw_request("tools/list", None) + assert ei.value.error.code == INTERNAL_ERROR + assert cleaned == [] + assert cleaned == [3, 2, 1] + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_expose_connection_and_transport(server: SrvT) -> None: + async with connected_runner(server, session_id="sess-abc", headers={"authorization": "Bearer t"}) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id == "sess-abc" + assert ctx.session_id == ctx.connection.session_id + assert ctx.headers == {"authorization": "Bearer t"} + assert ctx.headers is ctx.transport.headers + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_default_none(server: SrvT) -> None: + async with connected_runner(server) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id is None + assert ctx.headers is None From f2d4cba6232d71db56ab85416cb9ca2df6ad964a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 8 May 2026 14:23:21 +0000 Subject: [PATCH 27/27] fix: JSONRPCDispatcher coerces string response/progress IDs to int for correlation Matches BaseSession._normalize_request_id and the TypeScript SDK: a peer that echoes the request ID as a JSON string still resolves the waiter. Applied at both lookup sites (_resolve_pending and the progress-token match). Parity prep for the PR6 e2e suite. --- src/mcp/shared/jsonrpc_dispatcher.py | 19 +++++- tests/shared/test_jsonrpc_dispatcher.py | 83 +++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index f1e7b3675e..b450bb66d5 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -76,6 +76,21 @@ `TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" +def _coerce_id(request_id: RequestId) -> RequestId: + """Coerce a string request ID to int when it's a valid int literal. + + `_allocate_id` only ever produces ``int`` keys for ``_pending``, but a peer + may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` + both perform this coercion at lookup time so the response still correlates. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + @dataclass(slots=True) class _Pending: """An outbound request awaiting its response.""" @@ -409,7 +424,7 @@ def _dispatch_notification( if msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( - pending := self._pending.get(token) + pending := self._pending.get(_coerce_id(token)) ) is not None and pending.on_progress is not None: total = msg.params.get("total") message = msg.params.get("message") @@ -428,7 +443,7 @@ def _dispatch_notification( self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: - pending = self._pending.get(request_id) if request_id is not None else None + pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None if pending is None: logger.debug("dropping response for unknown/late request id %r", request_id) return diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 7f9f11718b..5755b55d15 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -18,6 +18,7 @@ from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, + _coerce_id, _outbound_metadata, _Pending, ) @@ -29,6 +30,7 @@ INVALID_PARAMS, ErrorData, JSONRPCError, + JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, Tool, @@ -511,6 +513,87 @@ def test_outbound_metadata_with_resumption_token_returns_client_metadata(): assert _outbound_metadata(None, {}) is None +@pytest.mark.anyio +async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A peer that echoes the request ID as a JSON string still resolves the waiter.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=str(rid), result={"ok": True})) + ) + + tg.start_soon(respond_stringly) + result = await client.send_raw_request("ping", None) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + seen: list[float] = [] + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_with_string_token_progress() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={"progressToken": str(rid), "progress": 0.5}, + ) + ) + ) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=rid, result={"ok": True})) + ) + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + seen.append(progress) + + tg.start_soon(respond_with_string_token_progress) + result = await client.send_raw_request("ping", None, {"on_progress": on_progress}) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [0.5] + + +def test_coerce_id_passes_through_non_numeric_string_and_int(): + assert _coerce_id("7") == 7 + assert _coerce_id("not-an-int") == "not-an-int" + assert _coerce_id(42) == 42 + + @pytest.mark.anyio async def test_jsonrpc_error_response_with_null_id_is_dropped(): """Parse-error responses (id=null) have no waiter; they're logged and dropped."""