diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d34e438fc9..341df0abb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,6 @@ on: branches: ["main", "v1.x"] tags: ["v*.*.*"] pull_request: - branches: ["main", "v1.x"] permissions: contents: read diff --git a/pyproject.toml b/pyproject.toml index 6d2319621a..5f51fa9b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "pillow>=12.0", "strict-no-cover", "logfire>=3.0.0", + "opentelemetry-sdk>=1.39.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py new file mode 100644 index 0000000000..4334b20a94 --- /dev/null +++ b/src/mcp/server/_typed_request.py @@ -0,0 +1,86 @@ +"""Typed ``send_request`` for server-to-client requests. + +`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over +the host's raw `Outbound.send_raw_request`. Spec server-to-client request types +have their result type inferred via per-type overloads; custom requests pass +``result_type=`` explicitly. + +If the spec's request set grows substantially, consider declaring the result +mapping on the request types themselves (a ``__mcp_result__`` ClassVar read via +a structural protocol) so this overload ladder doesn't need maintaining +per-host-class. +""" + +from typing import Any, TypeVar, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.peer import dump_params +from mcp.types import ( + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + Request, +) + +__all__ = ["TypedServerRequestMixin"] + +ResultT = TypeVar("ResultT", bound=BaseModel) + +_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = { + CreateMessageRequest: CreateMessageResult, + ElicitRequest: ElicitResult, + ListRootsRequest: ListRootsResult, + PingRequest: EmptyResult, +} + + +class TypedServerRequestMixin: + """Typed ``send_request`` for the server-to-client request set. + + Mixed into `Connection` and the server `Context`. Each method constrains + ``self`` to `Outbound` so any host with ``send_raw_request`` works. + """ + + @overload + async def send_request( + self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None + ) -> CreateMessageResult: ... + @overload + async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ... + @overload + async def send_request( + self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None + ) -> ListRootsResult: ... + @overload + async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ... + @overload + async def send_request( + self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None + ) -> ResultT: ... + async def send_request( + self: Outbound, + req: Request[Any, Any], + *, + result_type: type[BaseModel] | None = None, + opts: CallOptions | None = None, + ) -> BaseModel: + """Send a typed server-to-client request and return its typed result. + + For spec request types the result type is inferred. For custom requests + pass ``result_type=`` explicitly. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + KeyError: ``result_type`` omitted for a non-spec request type. + """ + raw = await self.send_raw_request(req.method, dump_params(req.params), opts) + cls = result_type if result_type is not None else _RESULT_FOR[type(req)] + return cls.model_validate(raw) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py new file mode 100644 index 0000000000..5991715a44 --- /dev/null +++ b/src/mcp/server/connection.py @@ -0,0 +1,158 @@ +"""`Connection` — per-client connection state and the standalone outbound channel. + +Always present on `Context` (never ``None``), even in stateless deployments. +Holds peer info populated at ``initialize`` time, per-connection scratch +``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the +standalone stream (the SSE GET stream in streamable HTTP, or the single duplex +stream in stdio). + +`notify` is best-effort: it never raises. If there's no standalone channel +(stateless HTTP) or the stream has been dropped, the notification is +debug-logged and silently discarded — server-initiated notifications are +inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when +there's no channel; `ping` is the only spec-sanctioned standalone request. +""" + +import logging +from collections.abc import Mapping +from contextlib import AsyncExitStack +from typing import Any + +import anyio + +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.peer import Meta, dump_params +from mcp.types import ClientCapabilities, Implementation, LoggingLevel + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + + +def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None: + if not meta: + return payload + out = dict(payload or {}) + out["_meta"] = meta + return out + + +class Connection(TypedServerRequestMixin): + """Per-client connection state and standalone-stream `Outbound`. + + Constructed by `ServerRunner` once per connection. The peer-info fields are + ``None`` until ``initialize`` completes; ``initialized`` is set then. + """ + + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: + self._outbound = outbound + self.has_standalone_channel = has_standalone_channel + self.session_id: str | None = session_id + + self.client_info: Implementation | None = None + self.client_capabilities: ClientCapabilities | None = None + self.protocol_version: str | None = None + self.initialized: anyio.Event = anyio.Event() + + self.state: dict[str, Any] = {} + """Per-connection scratch state. Handlers and middleware may read and + write freely; persists across requests on this connection.""" + + self.exit_stack: AsyncExitStack = AsyncExitStack() + """Cleanup stack unwound by `ServerRunner` when the connection closes. + + Push context managers (``await exit_stack.enter_async_context(...)``) + or callbacks (``exit_stack.push_async_callback(...)``) from handlers or + middleware to register per-connection teardown. Unwound LIFO after + `dispatcher.run()` returns, shielded from cancellation.""" + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a raw request on the standalone stream. + + Low-level `Outbound` channel. Prefer the typed ``send_request`` (from + `TypedServerRequestMixin`) or the convenience methods below; use this + directly only for off-spec messages. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + if not self.has_standalone_channel: + raise NoBackChannelError(method) + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a best-effort notification on the standalone stream. + + Never raises. If there's no standalone channel or the stream is broken, + the notification is dropped and debug-logged. + """ + if not self.has_standalone_channel: + logger.debug("dropped %s: no standalone channel", method) + return + try: + await self._outbound.notify(method, params) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped %s: standalone stream closed", method) + + async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request on the standalone stream. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``has_standalone_channel`` is ``False``. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a ``notifications/message`` log entry on the standalone stream. Best-effort.""" + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + await self.notify("notifications/message", _notification_params(params, meta)) + + async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/tools/list_changed", _notification_params(None, meta)) + + async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/prompts/list_changed", _notification_params(None, meta)) + + async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/list_changed", _notification_params(None, meta)) + + async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta)) + + def check_capability(self, capability: ClientCapabilities) -> bool: + """Return whether the connected client declared the given capability. + + Returns ``False`` if ``initialize`` hasn't completed yet. + """ + # TODO: redesign — mirrors v1 ServerSession.check_client_capability + # verbatim for parity. + if self.client_capabilities is None: + return False + have = self.client_capabilities + if capability.roots is not None: + if have.roots is None: + return False + if capability.roots.list_changed and not have.roots.list_changed: + return False + if capability.sampling is not None and have.sampling is None: + return False + if capability.elicitation is not None and have.elicitation is None: + return False + if capability.experimental is not None: + if have.experimental is None: + return False + for k in capability.experimental: + if k not in have.experimental: + return False + return True diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index d8e11d78b2..d1514a9add 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,14 +1,23 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.server.connection import Connection from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession from mcp.shared._context import RequestContext +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext from mcp.shared.message import CloseSSEStreamCallback +from mcp.shared.peer import Meta, PeerMixin +from mcp.shared.transport_context import TransportContext +from mcp.types import LoggingLevel, RequestParamsMeta LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @@ -21,3 +30,104 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None + + +LifespanT = TypeVar("LifespanT", default=Any, covariant=True) + + +class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]): + """Server-side per-request context. + + Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), + `PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``), + and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds + ``lifespan`` and ``connection``. + + Constructed by `ServerRunner` per inbound request and handed to the user's + handler. + """ + + def __init__( + self, + dctx: DispatchContext[TransportContext], + *, + lifespan: LifespanT, + connection: Connection, + meta: RequestParamsMeta | None = None, + ) -> None: + super().__init__(dctx, meta=meta) + self._lifespan = lifespan + self._connection = connection + + @property + def lifespan(self) -> LifespanT: + """The server-wide lifespan output (what `Server(..., lifespan=...)` yielded).""" + return self._lifespan + + @property + def connection(self) -> Connection: + """The per-client `Connection` for this request's connection.""" + return self._connection + + @property + def session_id(self) -> str | None: + """The transport's session id for this connection, when one exists. + + Convenience for ``ctx.connection.session_id``. ``None`` on stdio and + stateless HTTP. + """ + return self._connection.session_id + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Convenience for ``ctx.transport.headers``. ``None`` on stdio. + """ + return self.transport.headers + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a request-scoped ``notifications/message`` log entry. + + Uses this request's back-channel (so the entry rides the request's SSE + stream in streamable HTTP), not the standalone stream — use + ``ctx.connection.log(...)`` for that. + """ + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + if meta: + params["_meta"] = meta + await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) + + +class ServerMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + + Runs *inside* `ServerRunner._on_request` after params validation and + `Context` construction. Wraps registered handlers (including ``ping``) but + not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware can be typed + `ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it + registers on any `Server[L]`. + """ + + async def __call__( + self, + ctx: Context[_MwLifespanT], + method: str, + params: BaseModel, + call_next: CallNext, + ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 5e4e2e6f5b..375ca94c0d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -41,11 +41,13 @@ async def main(): import warnings from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from dataclasses import dataclass from importlib.metadata import version as importlib_version from typing import Any, Generic, cast import anyio from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -58,7 +60,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -76,6 +78,30 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT", default=Any) +_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel) + +RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]] +"""A registered request handler: ``(ctx, params) -> result``.""" + +NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]] +"""A registered notification handler: ``(ctx, params) -> None``.""" + + +@dataclass(frozen=True, slots=True) +class HandlerEntry(Generic[LifespanResultT]): + """A registered handler and the params model to validate incoming params against. + + Stored in `Server._request_handlers` / `_notification_handlers` and consumed + by `ServerRunner` to validate, build `Context`, and invoke. The handler's + second-argument type is erased to ``Any`` in storage (each entry has a + different concrete params type and `Callable` parameters are contravariant); + the precise type is recoverable via `params_type`. The correlation is + enforced at registration time by `Server.add_request_handler`. + """ + + params_type: type[BaseModel] + handler: RequestHandler[LifespanResultT, Any] + class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -85,7 +111,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Returns: @@ -109,6 +135,8 @@ def __init__( instructions: str | None = None, website_url: str | None = None, icons: list[types.Icon] | None = None, + notification_options: NotificationOptions | None = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, lifespan: Callable[ [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT], @@ -193,59 +221,96 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} - self._notification_handlers: dict[ - str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] - ] = {} + self._notification_options = notification_options or NotificationOptions() + self._experimental_capabilities = experimental_capabilities or {} + self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} + self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. + self.middleware: list[ServerMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) - # Populate internal handler dicts from on_* kwargs - self._request_handlers.update( - { - method: handler - for method, handler in { - "ping": on_ping, - "prompts/list": on_list_prompts, - "prompts/get": on_get_prompt, - "resources/list": on_list_resources, - "resources/templates/list": on_list_resource_templates, - "resources/read": on_read_resource, - "resources/subscribe": on_subscribe_resource, - "resources/unsubscribe": on_unsubscribe_resource, - "tools/list": on_list_tools, - "tools/call": on_call_tool, - "logging/setLevel": on_set_logging_level, - "completion/complete": on_completion, - }.items() - if handler is not None - } - ) + _spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [ + ("ping", types.RequestParams, on_ping), + ("prompts/list", types.PaginatedRequestParams, on_list_prompts), + ("prompts/get", types.GetPromptRequestParams, on_get_prompt), + ("resources/list", types.PaginatedRequestParams, on_list_resources), + ("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates), + ("resources/read", types.ReadResourceRequestParams, on_read_resource), + ("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource), + ("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource), + ("tools/list", types.PaginatedRequestParams, on_list_tools), + ("tools/call", types.CallToolRequestParams, on_call_tool), + ("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level), + ("completion/complete", types.CompleteRequestParams, on_completion), + ] + self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None}) + _spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [ + ("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed), + ("notifications/progress", types.ProgressNotificationParams, on_progress), + ] self._notification_handlers.update( - { - method: handler - for method, handler in { - "notifications/roots/list_changed": on_roots_list_changed, - "notifications/progress": on_progress, - }.items() - if handler is not None - } + {m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None} ) + def add_request_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: RequestHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a request handler for ``method``. + + ``params_type`` is the model incoming params are validated against + before the handler is invoked. It should subclass `RequestParams` so + ``_meta`` parses uniformly. Replaces any existing handler for the same + method (no collision guard against spec methods). + """ + self._request_handlers[method] = HandlerEntry(params_type, handler) + + def add_notification_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: NotificationHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a notification handler for ``method``. + + ``params_type`` should subclass `NotificationParams` so ``_meta`` + parses uniformly. Replaces any existing handler. + """ + self._notification_handlers[method] = HandlerEntry(params_type, handler) + def _add_request_handler( self, method: str, - handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + handler: RequestHandler[LifespanResultT, Any], ) -> None: - """Add a request handler, silently replacing any existing handler for the same method.""" - self._request_handlers[method] = handler + # TODO: remove once experimental tasks plumbing and remaining callers + # migrate to `add_request_handler` with an explicit params_type. + self.add_request_handler(method, types.RequestParams, handler) def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a request method, or ``None``.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a notification method, or ``None``.""" + return self._notification_handlers.get(method) + + def capabilities(self) -> types.ServerCapabilities: + """Derive `ServerCapabilities` from registered handlers and constructor options.""" + return self.get_capabilities(self._notification_options, self._experimental_capabilities) + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities @@ -461,7 +526,8 @@ async def _handle_request( attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, context=parent_context, ) as span: - if handler := self._request_handlers.get(req.method): + if entry := self._request_handlers.get(req.method): + handler = entry.handler logger.debug("Dispatching request of type %s", type(req).__name__) try: @@ -520,7 +586,8 @@ async def _handle_request( span.set_status(StatusCode.ERROR, response.message) try: - await message.respond(response) + # TODO: cast goes away when `_handle_request` is deleted. + await message.respond(cast(types.ServerResult | types.ErrorData, response)) except (anyio.BrokenResourceError, anyio.ClosedResourceError): # Transport closed between handler unblocking and respond. Happens # when _receive_loop's finally wakes a handler blocked on @@ -539,7 +606,8 @@ async def _handle_notification( session: ServerSession, lifespan_context: LifespanResultT, ) -> None: - if handler := self._notification_handlers.get(notify.method): + if entry := self._notification_handlers.get(notify.method): + handler = entry.handler logger.debug("Dispatching notification of type %s", type(notify).__name__) try: diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 0000000000..1ba732ec4b --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,228 @@ +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. It: + +* handles the ``initialize`` handshake and populates `Connection` +* gates requests until initialized (``ping`` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives ``dispatcher.run()`` and the per-connection lifespan + +`ServerRunner` holds a `Server` directly — `Server` is the registry. +""" + +from __future__ import annotations + +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from functools import partial, reduce +from typing import Any, Generic, cast + +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import CallNext, Context, ServerMiddleware +from mcp.server.lowlevel.server import Server +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + Implementation, + InitializeRequestParams, + InitializeResult, +) + +__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) + + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + + +def otel_middleware(next_on_request: OnRequest) -> OnRequest: + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. + + Mirrors the span shape of the existing `Server._handle_request`: span name + ``"MCP handle []"``, ``mcp.method.name`` attribute, W3C + trace context extracted from ``params._meta`` (SEP-414), and an ERROR + status if the handler raises. + """ + + async def wrapped( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + target: str | None + match params: + case {"name": str() as target}: + pass + case _: + target = None + parent: Any | None + match params: + case {"_meta": {**meta}}: + parent = extract_trace_context(meta) + case _: + parent = None + span_name = f"MCP handle {method}{f' {target}' if target else ''}" + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes={"mcp.method.name": method}, + context=parent, + record_exception=False, + set_status_on_exception=False, + ) as span: + try: + return await next_on_request(dctx, method, params) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except Exception as e: + span.record_exception(e) + span.set_status(StatusCode.ERROR, str(e)) + raise + + return wrapped + + +def _dump_result(result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, BaseModel): + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(result, dict): + return cast(dict[str, Any], result) + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") + + +@dataclass +class ServerRunner(Generic[LifespanT]): + """Per-connection orchestrator. One instance per client connection.""" + + server: Server[LifespanT] + dispatcher: Dispatcher[TransportContext] + lifespan_state: LifespanT + has_standalone_channel: bool + session_id: str | None = None + stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) + + connection: Connection = field(init=False) + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + self.connection = Connection( + self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id + ) + + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers + can ``await tg.start(runner.run)`` and resume once the dispatcher is + ready to accept requests. Once the dispatcher exits, + `connection.exit_stack` is unwound (shielded) so any per-connection + cleanup registered by handlers or middleware runs to completion. + """ + try: + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + finally: + with anyio.CancelScope(shield=True): + await self.connection.exit_stack.aclose() + + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` + and wraps everything — initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. + """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + raise MCPError( + code=INVALID_REQUEST, + message=f"Received {method!r} before initialization was complete", + ) + entry = self.server.get_request_handler(method) + if entry is None: + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. + typed_params = entry.params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + call: CallNext = partial(cast(Any, entry.handler), ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + entry = self.server.get_notification_handler(method) + if entry is None: + logger.debug("no handler for notification %s", method) + return + typed_params = entry.params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + # TODO: cast goes away when `ServerRequestContext = Context` lands. + await cast(Any, entry.handler)(ctx, typed_params) + + def _make_context(self, dctx: DispatchContext[TransportContext], typed_params: BaseModel) -> Context[LifespanT]: + meta = getattr(typed_params, "meta", None) + return Context(dctx, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}) + self.connection.client_info = init.client_info + self.connection.client_capabilities = init.capabilities + # TODO: real version negotiation. This always responds with LATEST, + # which is wrong — the server should pick the highest version both + # sides support and compute a per-connection feature set from it. + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". + self.connection.protocol_version = ( + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION + ) + self._initialized = True + self.connection.initialized.set() + result = InitializeResult( + protocol_version=self.connection.protocol_version, + capabilities=self.server.capabilities(), + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + ) + return _dump_result(result) diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 170e873a0f..553b8a0bce 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -20,9 +20,18 @@ def otel_span( kind: SpanKind, attributes: dict[str, Any] | None = None, context: Context | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, ) -> Iterator[Any]: """Create an OTel span.""" - with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + with _tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes, + context=context, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: yield span diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py new file mode 100644 index 0000000000..ff69c48401 --- /dev/null +++ b/src/mcp/shared/context.py @@ -0,0 +1,82 @@ +"""`BaseContext` — the user-facing per-request context. + +Composition over a `DispatchContext`: forwards the transport metadata, the +back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel +event. Adds `meta` (the inbound request's `_meta` field). + +Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` +mixes that in directly). Shared between client and server: the server's +`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an +alias. +""" + +from collections.abc import Mapping +from typing import Any, Generic + +import anyio +from typing_extensions import TypeVar + +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import RequestParamsMeta + +__all__ = ["BaseContext"] + +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) + + +class BaseContext(Generic[TransportT]): + """Per-request context wrapping a `DispatchContext`. + + `ServerRunner` constructs one per inbound request and passes it to the + user's handler. + """ + + def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: + self._dctx = dctx + self._meta = meta + + @property + def transport(self) -> TransportT: + """Transport-specific metadata for this inbound request.""" + return self._dctx.transport + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + return self._dctx.cancel_requested + + @property + def can_send_request(self) -> bool: + """Whether the back-channel can deliver server-initiated requests.""" + return self._dctx.transport.can_send_request + + @property + def meta(self) -> RequestParamsMeta | None: + """The inbound request's ``_meta`` field, if present.""" + return self._meta + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: ``can_send_request`` is ``False``. + """ + return await self._dctx.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer on the back-channel.""" + await self._dctx.notify(method, params) + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for this request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + await self._dctx.progress(progress, total, message) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py new file mode 100644 index 0000000000..1842cf8abc --- /dev/null +++ b/src/mcp/shared/direct_dispatcher.py @@ -0,0 +1,183 @@ +"""In-memory `Dispatcher` that wires two peers together with no transport. + +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a +request on one side directly invokes the other side's `on_request`. There is no +serialization, no JSON-RPC framing, and no streams. It exists to: + +* prove the `Dispatcher` Protocol is implementable without JSON-RPC +* provide a fast substrate for testing the layers above the dispatcher + (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts +* embed a server in-process when the JSON-RPC overhead is unnecessary + +Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly +to the caller — there is no exception-to-`ErrorData` boundary here. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + +import anyio +import anyio.abc + +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT + +__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] + +DIRECT_TRANSPORT_KIND = "direct" + + +_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] + + +@dataclass +class _DirectDispatchContext: + """`DispatchContext` for an inbound request on a `DirectDispatcher`. + + The back-channel callables target the *originating* side, so a handler's + `send_raw_request` reaches the peer that made the inbound request. + """ + + transport: TransportContext + _back_request: _Request + _back_notify: _Notify + _on_progress: ProgressFnT | None = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._back_notify(method, params) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.transport.can_send_request: + raise NoBackChannelError(method) + return await self._back_request(method, params, opts) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._on_progress is not None: + await self._on_progress(progress, total, message) + + +class DirectDispatcher: + """A `Dispatcher` that calls a peer's handlers directly, in-process. + + Two instances are wired together with `create_direct_dispatcher_pair`; each + holds a reference to the other. `send_raw_request` on one awaits the peer's + `on_request`. `run` parks until `close` is called. + """ + + def __init__(self, transport_ctx: TransportContext): + self._transport_ctx = transport_ctx + self._peer: DirectDispatcher | None = None + self._on_request: OnRequest | None = None + self._on_notify: OnNotify | None = None + self._ready = anyio.Event() + self._closed = anyio.Event() + + def connect_to(self, peer: DirectDispatcher) -> None: + self._peer = peer + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + return await self._peer._dispatch_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + await self._peer._dispatch_notify(method, params) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + self._on_request = on_request + self._on_notify = on_notify + self._ready.set() + task_status.started() + await self._closed.wait() + + def close(self) -> None: + self._closed.set() + + def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + assert self._peer is not None + peer = self._peer + return _DirectDispatchContext( + transport=self._transport_ctx, + _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), + _back_notify=lambda m, p: peer._dispatch_notify(m, p), + _on_progress=on_progress, + ) + + async def _dispatch_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None, + ) -> dict[str, Any]: + await self._ready.wait() + assert self._on_request is not None + opts = opts or {} + dctx = self._make_context(on_progress=opts.get("on_progress")) + try: + with anyio.fail_after(opts.get("timeout")): + try: + return await self._on_request(dctx, method, params) + except MCPError: + raise + except Exception as e: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + except TimeoutError: + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", + ) from None + + async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._ready.wait() + assert self._on_notify is not None + dctx = self._make_context() + await self._on_notify(dctx, method, params) + + +def create_direct_dispatcher_pair( + *, + can_send_request: bool = True, + headers: Mapping[str, str] | None = None, +) -> tuple[DirectDispatcher, DirectDispatcher]: + """Create two `DirectDispatcher` instances wired to each other. + + Args: + can_send_request: Sets `TransportContext.can_send_request` on both + sides. Pass ``False`` to simulate a transport with no back-channel. + headers: Sets `TransportContext.headers` on both sides. + + Returns: + A ``(left, right)`` pair. Conventionally ``left`` is the client side + and ``right`` is the server side, but the wiring is symmetric. + """ + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) + left = DirectDispatcher(ctx) + right = DirectDispatcher(ctx) + left.connect_to(right) + right.connect_to(left) + return left, right diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 0000000000..20c090323b --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,157 @@ +"""Dispatcher Protocol — the call/return boundary between transports and handlers. + +A Dispatcher turns a duplex message channel into two things: + +* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop + and invokes the supplied handlers for each incoming request/notification + +It is deliberately *not* MCP-aware. Method names are strings, params and +results are ``dict[str, Any]``. The MCP type layer (request/result models, +capability negotiation, ``Context``) sits above this; the wire encoding +(JSON-RPC, gRPC, in-process direct calls) sits below it. + +See ``JSONRPCDispatcher`` for the production implementation and +``DirectDispatcher`` for an in-memory implementation used in tests and for +embedding a server in-process. +""" + +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable + +import anyio +import anyio.abc + +from mcp.shared.transport_context import TransportContext + +__all__ = [ + "CallOptions", + "DispatchContext", + "DispatchMiddleware", + "Dispatcher", + "OnNotify", + "OnRequest", + "Outbound", + "ProgressFnT", +] + +TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) + + +class ProgressFnT(Protocol): + """Callback invoked when a progress notification arrives for a pending request.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class CallOptions(TypedDict, total=False): + """Per-call options for `Outbound.send_raw_request`. + + All keys are optional. Dispatchers ignore keys they do not understand. + """ + + timeout: float + """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" + + on_progress: ProgressFnT + """Receive ``notifications/progress`` updates for this request.""" + + resumption_token: str + """Opaque token to resume a previously interrupted request (transport-dependent).""" + + on_resumption_token: Callable[[str], Awaitable[None]] + """Receive a resumption token when the transport issues one.""" + + +@runtime_checkable +class Outbound(Protocol): + """Anything that can send requests and notifications to the peer. + + Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel + during an inbound request) extend this. The MCP type layer (`PeerMixin`, + `Connection`, `Context`) builds typed ``send_request`` / convenience methods + on top of this raw channel. + """ + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request and await its raw result dict. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... + + +class DispatchContext(Outbound, Protocol[TransportT_co]): + """Per-request context handed to ``on_request`` / ``on_notify``. + + Carries the transport metadata for the inbound message and provides the + back-channel for sending requests/notifications to the peer while handling + it. `send_raw_request` raises `NoBackChannelError` if + ``transport.can_send_request`` is ``False``. + """ + + @property + def transport(self) -> TransportT_co: + """Transport-specific metadata for this inbound message.""" + ... + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + ... + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + ... + + +OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" + +OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] +"""Handler for inbound notifications: ``(ctx, method, params)``.""" + +DispatchMiddleware = Callable[[OnRequest], OnRequest] +"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first.""" + + +class Dispatcher(Outbound, Protocol[TransportT_co]): + """A duplex request/notification channel with call-return semantics. + + Implementations own correlation of outbound requests to inbound results, the + receive loop, per-request concurrency, and cancellation/progress wiring. + """ + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the underlying channel closes. + + Each inbound request is dispatched to ``on_request`` in its own task; + the returned dict (or raised ``MCPError``) is sent back as the response. + Inbound notifications go to ``on_notify``. + + ``task_status.started()`` is called once the dispatcher is ready to + accept ``send_request``/``notify`` calls, so callers can use + ``await tg.start(dispatcher.run, on_request, on_notify)``. + """ + ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319d..b62629b6c8 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -2,7 +2,7 @@ from typing import Any, cast -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError class MCPError(Exception): @@ -41,6 +41,25 @@ def __str__(self) -> str: return self.message +class NoBackChannelError(MCPError): + """Raised when sending a server-initiated request over a transport that cannot deliver it. + + Stateless HTTP and JSON-response-mode HTTP have no channel for the server to + push requests (sampling, elicitation, roots/list) to the client. This is + raised by `DispatchContext.send_raw_request` when `transport.can_send_request` + is ``False``, and serializes to an ``INVALID_REQUEST`` error response. + """ + + def __init__(self, method: str): + super().__init__( + code=INVALID_REQUEST, + message=( + f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests." + ), + ) + self.method = method + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py new file mode 100644 index 0000000000..b450bb66d5 --- /dev/null +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -0,0 +1,558 @@ +"""JSON-RPC `Dispatcher` implementation. + +Consumes the existing `SessionMessage`-based stream contract that all current +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, +the receive loop, per-request task isolation, cancellation/progress wiring, and +the single exception-to-wire boundary. + +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and +sees only `(ctx, method, params) -> dict`. Transports sit below and see only +`SessionMessage` reads/writes. + +The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and +dicts — but it intercepts ``notifications/cancelled`` and +``notifications/progress`` because request correlation, cancellation and +progress are exactly the wiring this layer exists to provide. Those few wire +shapes are extracted with structural ``match`` patterns (no casts, no +``mcp.types`` model coupling); a malformed payload simply fails to match and +the correlation is skipped. +""" + +from __future__ import annotations + +import contextvars +import logging +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar, cast, overload + +import anyio +import anyio.abc +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError + +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + ServerMessageMetadata, + SessionMessage, +) +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + REQUEST_CANCELLED, + REQUEST_TIMEOUT, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ProgressToken, + RequestId, +) + +__all__ = ["JSONRPCDispatcher"] + +logger = logging.getLogger(__name__) + +TransportT = TypeVar("TransportT", bound=TransportContext) + +PeerCancelMode = Literal["interrupt", "signal"] +"""How inbound ``notifications/cancelled`` is applied to a running handler. + +``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets +``ctx.cancel_requested`` and lets the handler observe it cooperatively. +""" + +TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] +"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and +the `SessionMessage.metadata` the transport attached. Defaults to a plain +`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" + + +def _coerce_id(request_id: RequestId) -> RequestId: + """Coerce a string request ID to int when it's a valid int literal. + + `_allocate_id` only ever produces ``int`` keys for ``_pending``, but a peer + may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` + both perform this coercion at lookup time so the response still correlates. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + +@dataclass(slots=True) +class _Pending: + """An outbound request awaiting its response.""" + + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] + on_progress: ProgressFnT | None = None + + +@dataclass(slots=True) +class _InFlight(Generic[TransportT]): + """An inbound request currently being handled.""" + + scope: anyio.CancelScope + dctx: _JSONRPCDispatchContext[TransportT] + cancelled_by_peer: bool = False + + +@dataclass +class _JSONRPCDispatchContext(Generic[TransportT]): + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" + + transport: TransportT + _dispatcher: JSONRPCDispatcher[TransportT] + _request_id: RequestId | None + _progress_token: ProgressToken | None = None + _closed: bool = False + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request and not self._closed + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._dispatcher.send_raw_request(method, params, opts, _related_request_id=self._request_id) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._progress_token is None: + return + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) + + def close(self) -> None: + self._closed = True + + +def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: + """Choose the `SessionMessage.metadata` for an outgoing request/notification. + + `ServerMessageMetadata` tags a server-to-client message with the inbound + request it belongs to (so streamable-HTTP can route it onto that request's + SSE stream). `ClientMessageMetadata` carries resumption hints to the + client transport. ``None`` is the common case. + """ + if related_request_id is not None: + return ServerMessageMetadata(related_request_id=related_request_id) + if opts: + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") + if token is not None or on_token is not None: + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) + return None + + +class JSONRPCDispatcher(Dispatcher[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract. + + Inherits the `Dispatcher` Protocol explicitly so pyright checks + conformance at the class definition rather than at first use. + """ + + @overload + def __init__( + self: JSONRPCDispatcher[TransportContext], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + ) -> None: ... + @overload + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: ... + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + # The overloads guarantee that when `transport_builder` is omitted, + # `TransportT` is `TransportContext`, so the default is type-correct; + # pyright can't see across overloads, hence the cast. + self._transport_builder = cast( + "Callable[[RequestId | None, MessageMetadata], TransportT]", + transport_builder or _default_transport_builder, + ) + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode + self._raise_handler_exceptions = raise_handler_exceptions + + self._next_id = 0 + self._pending: dict[RequestId, _Pending] = {} + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._tg: anyio.abc.TaskGroup | None = None + self._running = False + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: RequestId | None = None, + ) -> dict[str, Any]: + """Send a JSON-RPC request and await its response. + + ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a + handler makes a server-to-client request mid-flight; it routes the + outgoing message onto the correct per-request SSE stream (SHTTP) via + `ServerMessageMetadata`. Top-level callers leave it ``None``. + + Raises: + MCPError: The peer responded with a JSON-RPC error; or + ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or + ``CONNECTION_CLOSED`` if the dispatcher shut down while + awaiting the response. + RuntimeError: Called before ``run()`` has started or after it has + finished. + """ + if not self._running: + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") + opts = opts or {} + request_id = self._allocate_id() + out_params = dict(params) if params is not None else None + on_progress = opts.get("on_progress") + if on_progress is not None: + # The caller wants progress updates. The spec mechanism is: include + # `_meta.progressToken` on the request; the peer echoes that token on + # any `notifications/progress` it sends. We use the request id as the + # token so the receive loop can find this `_Pending.on_progress` by + # `_pending[token]` without a second lookup table. + meta = dict((out_params or {}).get("_meta") or {}) + meta["progressToken"] = request_id + out_params = {**(out_params or {}), "_meta": meta} + + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an + # outcome and dropping the late/redundant signal is correct. buffer=0 + # is unsafe — there's a window between registering `_pending[id]` and + # parking in `receive()` where a close signal would be lost. + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + pending = _Pending(send=send, receive=receive, on_progress=on_progress) + self._pending[request_id] = pending + + metadata = _outbound_metadata(_related_request_id, opts) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + try: + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() + except TimeoutError: + # Spec-recommended courtesy: tell the peer we've given up so it can + # stop work and free resources. v1's BaseSession.send_request does + # NOT do this; it's new behaviour. + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None + except anyio.get_cancelled_exc_class(): + # Our caller's scope was cancelled. We're already inside a cancelled + # scope, so any bare `await` here re-raises immediately — shield to + # let the courtesy cancel notification go out before we propagate. + with anyio.CancelScope(shield=True): + await self._cancel_outbound(request_id, "caller cancelled") + raise + finally: + # Always remove the waiter, even on cancel/timeout, so a late + # response from the peer (race) hits a closed stream and is dropped + # in `_dispatch` rather than leaking. + self._pending.pop(request_id, None) + send.close() + receive.close() + + if isinstance(outcome, ErrorData): + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) + return outcome + + async def notify( + self, + method: str, + params: Mapping[str, Any] | None, + *, + _related_request_id: RequestId | None = None, + ) -> None: + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) + await self._write(msg, _outbound_metadata(_related_request_id, None)) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the read stream closes. + + Each inbound request is handled in its own task in an internal task + group; ``task_status.started()`` fires once that group is open, so + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_raw_request`` + is usable. + """ + try: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + async with self._read_stream: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + self._dispatch(item, on_request, on_notify, sender_ctx) + # Read stream EOF: wake any blocked `send_raw_request` waiters now, + # *before* the task group joins, so handlers parked in + # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. + self._running = False + self._fan_out_closed() + finally: + # Covers the cancel/crash paths where the inline fan-out above is + # never reached. Idempotent. + self._running = False + self._tg = None + self._fan_out_closed() + + def _dispatch( + self, + item: SessionMessage | Exception, + on_request: OnRequest, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + """Route one inbound item. Synchronous: never awaits. + + Everything here is `send_nowait` or `_spawn`. An `await` would let one + slow message head-of-line block the entire read loop. + """ + if isinstance(item, Exception): + logger.debug("transport yielded exception: %r", item) + return + metadata = item.metadata + msg = item.message + match msg: + case JSONRPCRequest(): + self._dispatch_request(msg, metadata, on_request, sender_ctx) + case JSONRPCNotification(): + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) + case JSONRPCResponse(): + self._resolve_pending(msg.id, msg.result) + case JSONRPCError(): # pragma: no branch + # `id` may be None per JSON-RPC (parse error before id known). + # The match is exhaustive over JSONRPCMessage; the no-match arc + # on this final case is unreachable. + self._resolve_pending(msg.id, msg.error) + + def _dispatch_request( + self, + req: JSONRPCRequest, + metadata: MessageMetadata, + on_request: OnRequest, + sender_ctx: contextvars.Context | None, + ) -> None: + progress_token: ProgressToken | None + match req.params: + case {"_meta": {"progressToken": str() | int() as progress_token}}: + pass + case _: + progress_token = None + transport_ctx = self._transport_builder(req.id, metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, + _dispatcher=self, + _request_id=req.id, + _progress_token=progress_token, + ) + scope = anyio.CancelScope() + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + + def _dispatch_notification( + self, + msg: JSONRPCNotification, + metadata: MessageMetadata, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + if msg.method == "notifications/cancelled": + match msg.params: + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: + in_flight.cancelled_by_peer = True + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() + case _: + pass + return + if msg.method == "notifications/progress": + match msg.params: + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( + pending := self._pending.get(_coerce_id(token)) + ) is not None and pending.on_progress is not None: + total = msg.params.get("total") + message = msg.params.get("message") + self._spawn( + pending.on_progress, + float(progress), + float(total) if isinstance(total, int | float) else None, + message if isinstance(message, str) else None, + sender_ctx=sender_ctx, + ) + case _: + pass + # fall through: progress is also teed to on_notify + transport_ctx = self._transport_builder(None, metadata) + dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: + pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None + if pending is None: + logger.debug("dropping response for unknown/late request id %r", request_id) + return + try: + pending.send.send_nowait(outcome) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("waiter for request id %r already gone", request_id) + + def _spawn( + self, + fn: Callable[..., Awaitable[Any]], + *args: object, + sender_ctx: contextvars.Context | None, + ) -> None: + """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. + + ASGI middleware (auth, OTel) sets contextvars on the request task that + wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes + the spawned handler inherit *that* context instead of the receive + loop's, so ``auth_context_var`` and OTel spans survive. + """ + assert self._tg is not None + if sender_ctx is not None: + sender_ctx.run(self._tg.start_soon, fn, *args) + else: + self._tg.start_soon(fn, *args) + + def _fan_out_closed(self) -> None: + """Wake every pending ``send_raw_request`` waiter with ``CONNECTION_CLOSED``. + + Synchronous (uses ``send_nowait``) because it's called from ``finally`` + which may be inside a cancelled scope. Idempotent. + """ + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + for pending in self._pending.values(): + try: + pending.send.send_nowait(closed) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + pass + self._pending.clear() + + async def _handle_request( + self, + req: JSONRPCRequest, + dctx: _JSONRPCDispatchContext[TransportT], + scope: anyio.CancelScope, + on_request: OnRequest, + ) -> None: + """Run ``on_request`` for one inbound request and write its response. + + This is the single exception-to-wire boundary: handler exceptions are + caught here and serialized to ``JSONRPCError``. Nothing above this in + the stack constructs wire errors. + """ + try: + with scope: + try: + result = await on_request(dctx, req.method, req.params) + finally: + # Close the back-channel the moment the handler exits + # (success or raise), before the response write — a handler + # spawning detached work that later calls + # `dctx.send_raw_request()` should see `NoBackChannelError`. + dctx.close() + await self._write_result(req.id, result) + # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio + # swallows a scope's *own* cancel at __exit__, so the result write + # (or the handler) is interrupted and execution lands here without + # reaching the `except cancelled` arm below. Spec SHOULD: send no + # response — fall through to `finally`. + except anyio.get_cancelled_exc_class(): + # Outer-cancel: run()'s task group is shutting down. Any bare + # `await` here re-raises immediately, so shield the courtesy write. + with anyio.CancelScope(shield=True): + await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + raise + except MCPError as e: + await self._write_error(req.id, e.error) + except ValidationError as e: + await self._write_error(req.id, ErrorData(code=INVALID_PARAMS, message=str(e))) + except Exception as e: + logger.exception("handler for %r raised", req.method) + await self._write_error(req.id, ErrorData(code=INTERNAL_ERROR, message=str(e))) + if self._raise_handler_exceptions: + raise + finally: + self._in_flight.pop(req.id, None) + + def _allocate_id(self) -> int: + self._next_id += 1 + return self._next_id + + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + + async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + try: + await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped result for %r: write stream closed", request_id) + + async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + try: + await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped error for %r: write stream closed", request_id) + + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: + try: + await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + pass diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py new file mode 100644 index 0000000000..47b64c7769 --- /dev/null +++ b/src/mcp/shared/peer.py @@ -0,0 +1,216 @@ +"""Typed MCP request sugar over an `Outbound`. + +`PeerMixin` defines the server-to-client request methods (sampling, elicitation, +roots, ping) once. Any class that satisfies `Outbound` (i.e. has +``send_raw_request`` and ``notify``) can mix it in and get the typed methods for +free — `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. + +The mixin does no capability gating: it builds the params, calls +``self.send_raw_request(method, params)``, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_raw_request`'s job. +""" + +from collections.abc import Mapping +from typing import Any, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + IncludeContext, + ListRootsResult, + ModelPreferences, + SamplingMessage, + Tool, + ToolChoice, +) + +__all__ = ["Meta", "Peer", "PeerMixin", "dump_params"] + +Meta = dict[str, Any] +"""Type alias for the ``_meta`` field carried on request/notification params.""" + + +def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: + """Serialize a params model to a wire dict, merging ``meta`` into ``_meta``. + + Shared by `PeerMixin`, `Connection`, and `TypedServerRequestMixin` so every + typed convenience method gets the same `_meta` handling. ``meta`` keys take + precedence over any ``_meta`` already present on the model. + """ + out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None + if meta: + out = dict(out or {}) + out["_meta"] = {**out.get("_meta", {}), **meta} + return out + + +class PeerMixin: + """Typed server-to-client request methods. + + Each method constrains ``self`` to `Outbound` so the mixin can be applied + to anything with ``send_raw_request``/``notify`` — pyright checks the host + class structurally at the call site. + """ + + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: None = None, + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult: ... + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool], + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResultWithTools: ... + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult | CreateMessageResultWithTools: + """Send a ``sampling/createMessage`` request to the peer. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: The host's transport context has no + back-channel for server-initiated requests. + """ + params = CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ) + result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) + if tools is not None: + return CreateMessageResultWithTools.model_validate(result) + return CreateMessageResult.model_validate(result) + + async def elicit_form( + self: Outbound, + message: str, + requested_schema: ElicitRequestedSchema, + *, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a form-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) + return ElicitResult.model_validate(result) + + async def elicit_url( + self: Outbound, + message: str, + url: str, + elicitation_id: str, + *, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a URL-mode ``elicitation/create`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) + return ElicitResult.model_validate(result) + + async def list_roots( + self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None + ) -> ListRootsResult: + """Send a ``roots/list`` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) + return ListRootsResult.model_validate(result) + + async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a ``ping`` request and ignore the result. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + +class Peer(PeerMixin): + """Standalone wrapper that gives any `Outbound` the `PeerMixin` sugar. + + `Context` and `Connection` mix `PeerMixin` in directly; use `Peer` when + you have a bare dispatcher (or any `Outbound`) and want the typed methods + without writing your own host class. + """ + + def __init__(self, outbound: Outbound) -> None: + self._outbound = outbound + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._outbound.notify(method, params) diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py new file mode 100644 index 0000000000..9346116707 --- /dev/null +++ b/src/mcp/shared/transport_context.py @@ -0,0 +1,38 @@ +"""Transport-specific metadata attached to each inbound message. + +`TransportContext` is the base; each transport defines its own subclass with +whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, +etc.). The dispatcher passes it through opaquely; only the layers above the +dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +""" + +from collections.abc import Mapping +from dataclasses import dataclass + +__all__ = ["TransportContext"] + + +@dataclass(kw_only=True, frozen=True) +class TransportContext: + """Base transport metadata for an inbound message. + + Subclass per transport and add fields as needed. Instances are immutable. + """ + + kind: str + """Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``).""" + + can_send_request: bool + """Whether the transport can deliver server-initiated requests to the peer. + + ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for + stdio, SSE, and stateful streamable HTTP. When ``False``, + `DispatchContext.send_raw_request` raises `NoBackChannelError`. + """ + + headers: Mapping[str, str] | None = None + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; ``None`` on stdio. Handlers should + None-check before use. + """ diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b442303937..ca1c328939 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -192,6 +192,7 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, + REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -401,6 +402,7 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", + "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 84304a37c1..14743c33b0 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,6 +43,7 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 +REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/conftest.py b/tests/conftest.py index af7e479932..b83c472135 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,16 @@ +import os + import pytest +# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite +# uses a single span-capture mechanism: logfire's `capfire` fixture (its +# `configure()` swaps span processors on repeat calls rather than re-setting +# the provider). Logfire's default `distributed_tracing=None` emits a +# RuntimeWarning + diagnostic span when incoming W3C trace context is +# extracted; several tests exercise that propagation deliberately, so opt in +# suite-wide. Set before logfire is imported anywhere. +os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") + @pytest.fixture def anyio_backend(): diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 0000000000..290ccc957a --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,45 @@ +"""Shared fixtures for server-side tests.""" + +from collections.abc import Iterator + +import pytest +from logfire.testing import CaptureLogfire, TestExporter +from opentelemetry.sdk.trace import ReadableSpan + + +class SpanCapture: + """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. + + `finished()` returns the raw `ReadableSpan` objects emitted by the + ``mcp-python-sdk`` instrumentation scope, filtered to exclude logfire's + synthetic ``pending_span`` markers, so tests can assert directly on + `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. + """ + + def __init__(self, exporter: TestExporter) -> None: + self._exporter = exporter + + def clear(self) -> None: + self._exporter.clear() + + def finished(self) -> list[ReadableSpan]: + return [ + s + for s in self._exporter.exported_spans + if s.instrumentation_scope is not None + and s.instrumentation_scope.name == "mcp-python-sdk" + and not (s.attributes and s.attributes.get("logfire.span_type") == "pending_span") + ] + + +@pytest.fixture +def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: + """In-memory MCP span capture, cleared before and after each test. + + Backed by the project-level `capfire` override (see ``tests/conftest.py``) + so there is a single global tracer provider for the suite. + """ + capture = SpanCapture(capfire.exporter) + capture.clear() + yield capture + capture.clear() diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py new file mode 100644 index 0000000000..ded9dfd6ac --- /dev/null +++ b/tests/server/test_connection.py @@ -0,0 +1,205 @@ +"""Tests for `Connection`. + +`Connection` wraps an `Outbound` (the standalone stream). Its `notify` is +best-effort (never raises); `send_raw_request` is gated on +``has_standalone_channel``. Tested with a stub `Outbound` so we can assert wire +shape and inject failures. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.types import ( + ClientCapabilities, + ElicitationCapability, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + RootsCapability, + SamplingCapability, +) + + +class StubOutbound: + def __init__( + self, *, result: dict[str, Any] | None = None, raise_on_send: type[BaseException] | None = None + ) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self._result = result if result is not None else {} + self._raise_on_send = raise_on_send + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.requests.append((method, params)) + return self._result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._raise_on_send is not None: + raise self._raise_on_send() + self.notifications.append((method, params)) + + +@pytest.mark.anyio +async def test_connection_notify_forwards_to_outbound(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"level": "info", "data": "hi"}) + assert out.notifications == [("notifications/message", {"level": "info", "data": "hi"})] + + +@pytest.mark.anyio +async def test_connection_notify_swallows_broken_stream_and_debug_logs(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound(raise_on_send=anyio.BrokenResourceError) + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert "stream closed" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_notify_drops_when_no_standalone_channel(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound() + conn = Connection(out, has_standalone_channel=False) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert out.notifications == [] + assert "no standalone channel" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalone_channel(): + conn = Connection(StubOutbound(), has_standalone_channel=False) + with pytest.raises(NoBackChannelError): + await conn.send_raw_request("ping", None) + + +@pytest.mark.anyio +async def test_connection_send_raw_request_forwards_when_standalone_channel_present(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_raw_request("ping", None) + assert out.requests == [("ping", None)] + assert result == {} + + +@pytest.mark.anyio +async def test_connection_send_request_with_spec_type_infers_result_type(): + out = StubOutbound(result={"roots": [{"uri": "file:///ws"}]}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(ListRootsRequest()) + method, _ = out.requests[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert str(result.roots[0].uri) == "file:///ws" + + +@pytest.mark.anyio +async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): + out = StubOutbound(result={}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(PingRequest(), result_type=EmptyResult) + assert isinstance(result, EmptyResult) + + +@pytest.mark.anyio +async def test_connection_ping_sends_ping_on_standalone(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.ping() + assert out.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_connection_log_sends_logging_message_notification(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", {"k": "v"}, logger="my.logger") + method, params = out.notifications[0] + assert method == "notifications/message" + assert params is not None + assert params["level"] == "info" + assert params["data"] == {"k": "v"} + assert params["logger"] == "my.logger" + + +@pytest.mark.anyio +async def test_connection_log_with_meta_includes_meta_in_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", "x", meta={"traceId": "abc"}) + _, params = out.notifications[0] + assert params is not None + assert params["_meta"] == {"traceId": "abc"} + + +@pytest.mark.anyio +async def test_connection_list_changed_notifications_send_correct_methods(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed() + await conn.send_prompt_list_changed() + await conn.send_resource_list_changed() + await conn.send_resource_updated("file:///workspace/a.txt") + methods = [m for m, _ in out.notifications] + assert methods == [ + "notifications/tools/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/list_changed", + "notifications/resources/updated", + ] + assert out.notifications[-1][1] == {"uri": "file:///workspace/a.txt"} + + +@pytest.mark.anyio +async def test_connection_send_tool_list_changed_with_meta_includes_meta_only_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed(meta={"k": 1}) + assert out.notifications == [("notifications/tools/list_changed", {"_meta": {"k": 1}})] + + +def test_connection_check_capability_false_before_initialized(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False + + +@pytest.mark.parametrize( + ("have", "want", "expected"), + [ + (ClientCapabilities(roots=None), ClientCapabilities(roots=RootsCapability()), False), + ( + ClientCapabilities(roots=RootsCapability(list_changed=False)), + ClientCapabilities(roots=RootsCapability(list_changed=True)), + False, + ), + (ClientCapabilities(sampling=None), ClientCapabilities(sampling=SamplingCapability()), False), + (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"a": {}}), True), + ], +) +def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = have + assert conn.check_capability(want) is expected + + +def test_connection_check_capability_true_when_client_declares_it(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_capabilities = ClientCapabilities( + sampling=SamplingCapability(), roots=RootsCapability(list_changed=True) + ) + conn.initialized.set() + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert conn.check_capability(ClientCapabilities(roots=RootsCapability(list_changed=True))) is True + assert conn.check_capability(ClientCapabilities(elicitation=ElicitationCapability())) is False diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 0000000000..33df234dbe --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,482 @@ +"""Tests for `ServerRunner`. + +End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +registry. The `connected_runner` helper starts both sides and (by default) +performs the initialize handshake, so each test exercises only the behaviour +under test. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, cast + +import anyio +import anyio.lowlevel +import pytest +from opentelemetry.trace import SpanKind, StatusCode + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import NotificationOptions, Server +from mcp.server.runner import ServerRunner, otel_middleware +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchMiddleware +from mcp.shared.exceptions import MCPError +from mcp.types import ( + INTERNAL_ERROR, + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + ClientCapabilities, + Implementation, + InitializeRequestParams, + ListToolsResult, + NotificationParams, + PaginatedRequestParams, + RequestParams, + SetLevelRequestParams, + Tool, +) + +from ..shared.test_dispatcher import Recorder, echo_handlers +from .conftest import SpanCapture + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Context[Any]] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + # ctx is `Any` while `on_*` kwargs are typed against `ServerRequestContext` + # but `ServerRunner` passes the new `Context`; tightens once the alias lands. + _seen_ctx.append(ctx) + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@asynccontextmanager +async def connected_runner( + server: SrvT, + *, + initialized: bool = True, + stateless: bool = False, + has_standalone_channel: bool = True, + session_id: str | None = None, + headers: Mapping[str, str] | None = None, + dispatch_middleware: list[DispatchMiddleware] | None = None, +) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]: + """Yield ``(client, runner)`` running over an in-memory dispatcher pair. + + Starts the client (echo handlers) and `runner.run()` in a task group, wraps + the body in ``anyio.fail_after(5)``, and cancels on exit. When + ``initialized`` is true the helper performs the real ``initialize`` request + before yielding, so tests start past the init-gate via the public path. + """ + client, server_d = create_direct_dispatcher_pair(headers=headers) + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state={}, + has_standalone_channel=has_standalone_channel, + session_id=session_id, + stateless=stateless, + dispatch_middleware=dispatch_middleware or [], + ) + c_req, c_notify = echo_handlers(Recorder()) + body_exc: BaseException | None = None + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + try: + with anyio.fail_after(5): + if initialized: + await client.send_raw_request("initialize", _initialize_params()) + yield client, runner + except BaseException as e: + # Capture and re-raise outside the task group so test failures + # surface as the original exception, not an ExceptionGroup wrapper. + body_exc = e + client.close() + server_d.close() + if body_exc is not None: + raise body_exc + + +@pytest.mark.anyio +async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): + """The harness re-raises body exceptions as-is, not as ``ExceptionGroup``.""" + with pytest.raises(RuntimeError, match="boom"): + async with connected_runner(server): + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt from the gate + assert await client.send_raw_request("ping", None) == {} + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_and_builds_context(server: SrvT): + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan == {} + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" + + +@pytest.mark.anyio +async def test_runner_unknown_method_raises_method_not_found(server: SrvT): + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + + async def on_roots_changed(ctx: Any, params: Any) -> None: + seen.append((ctx, params)) + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. + + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + async with connected_runner(server, dispatch_middleware=[trace_mw]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + async with connected_runner(server) as (client, _): + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize (sent by the helper) NOT wrapped; ping and tools/list ARE. + assert seen_methods == ["ping", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + + +@pytest.mark.anyio +async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): + async def set_level(ctx: Any, params: Any) -> None: + return None + + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert result == {} + + +@pytest.mark.anyio +async def test_runner_handler_returning_unsupported_type_surfaces_as_internal_error(server: SrvT): + async def bad_return(ctx: Any, params: Any) -> int: + return 42 + + # cast: deliberately registering a handler with a bad return type to + # exercise the runtime check; pyright would (correctly) reject it otherwise. + server.add_request_handler("tools/list", PaginatedRequestParams, cast(Any, bad_return)) + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert "int" in exc.value.error.message + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + + +@pytest.mark.anyio +async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): + class GreetParams(RequestParams): + name: str + + received: list[GreetParams] = [] + + async def greet(ctx: Any, params: GreetParams) -> dict[str, Any]: + received.append(params) + return {"greeting": f"hello {params.name}"} + + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("custom/greet", {"name": "world"}) + assert result == {"greeting": "hello world"} + assert isinstance(received[0], GreetParams) + assert received[0].name == "world" + + +@pytest.mark.anyio +async def test_server_capabilities_reflects_ctor_options_in_initialize_result(): + async def list_tools(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + server: SrvT = Server( + name="caps-test", + on_list_tools=list_tools, + notification_options=NotificationOptions(tools_changed=True), + experimental_capabilities={"ext": {"k": "v"}}, + ) + async with connected_runner(server, initialized=False) as (client, _): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["capabilities"]["tools"]["listChanged"] is True + assert result["capabilities"]["experimental"] == {"ext": {"k": "v"}} + + +@pytest.mark.anyio +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): + async def call_tool(ctx: Any, params: Any) -> dict[str, Any]: + return {"content": [], "isError": False} + + server.add_request_handler("tools/call", CallToolRequestParams, call_tool) + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + [span] = spans.finished() + assert span.name == "MCP handle tools/call mytool" + assert span.kind == SpanKind.SERVER + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert span.status.status_code == StatusCode.UNSET + + +@pytest.mark.anyio +async def test_otel_middleware_extracts_parent_context_from_meta(server: SrvT, spans: SpanCapture): + parent_span_id = "b7ad6b7169203331" + traceparent = f"00-0af7651916cd43dd8448eb211c80319c-{parent_span_id}-01" + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", {"_meta": {"traceparent": traceparent}}) + [span] = spans.finished() + assert span.parent is not None + assert format(span.parent.span_id, "016x") == parent_span_id + assert span.context is not None + assert format(span.context.trace_id, "032x") == "0af7651916cd43dd8448eb211c80319c" + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + [span] = spans.finished() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found: nonexistent/method" + # MCPError is a protocol-level response, not a crash — no traceback event. + assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): + async def failing(ctx: Any, params: Any) -> Any: + raise ValueError("handler blew up") + + server.add_request_handler("tools/list", PaginatedRequestParams, failing) + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + [span] = spans.finished() + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" + + +@pytest.mark.anyio +async def test_connection_state_persists_across_requests_on_same_connection(server: SrvT) -> None: + async def count(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + ctx.connection.state["n"] = ctx.connection.state.get("n", 0) + 1 + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, count) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + await client.send_raw_request("tools/list", None) + assert runner.connection.state == {"n": 2} + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_pushed_callback_after_close(server: SrvT) -> None: + cleaned: list[str] = [] + + async def push(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + async def _cleanup() -> None: + cleaned.append("done") + + ctx.connection.exit_stack.push_async_callback(_cleanup) + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, push) + async with connected_runner(server) as (client, _runner): + await client.send_raw_request("tools/list", None) + assert cleaned == [] + assert cleaned == ["done"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_unwinds_entered_context_manager_after_close(server: SrvT) -> None: + events: list[str] = [] + + class _Tracker(AbstractAsyncContextManager[str]): + async def __aenter__(self) -> str: + events.append("enter") + return "resource" + + async def __aexit__(self, *exc: object) -> None: + events.append("exit") + + async def acquire(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + res = await ctx.connection.exit_stack.enter_async_context(_Tracker()) + ctx.connection.state["res"] = res + return ListToolsResult(tools=[]) + + server.add_request_handler("tools/list", PaginatedRequestParams, acquire) + async with connected_runner(server) as (client, runner): + await client.send_raw_request("tools/list", None) + assert events == ["enter"] + assert runner.connection.state["res"] == "resource" + assert events == ["enter", "exit"] + + +@pytest.mark.anyio +async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error(server: SrvT) -> None: + cleaned: list[int] = [] + + async def push_then_fail(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult: + for i in (1, 2, 3): + ctx.connection.exit_stack.push_async_callback(_append, i) + raise RuntimeError("boom") + + async def _append(i: int) -> None: + cleaned.append(i) + + server.add_request_handler("tools/list", PaginatedRequestParams, push_then_fail) + async with connected_runner(server) as (client, _runner): + with pytest.raises(MCPError) as ei: + await client.send_raw_request("tools/list", None) + assert ei.value.error.code == INTERNAL_ERROR + assert cleaned == [] + assert cleaned == [3, 2, 1] + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_expose_connection_and_transport(server: SrvT) -> None: + async with connected_runner(server, session_id="sess-abc", headers={"authorization": "Bearer t"}) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id == "sess-abc" + assert ctx.session_id == ctx.connection.session_id + assert ctx.headers == {"authorization": "Bearer t"} + assert ctx.headers is ctx.transport.headers + + +@pytest.mark.anyio +async def test_context_session_id_and_headers_default_none(server: SrvT) -> None: + async with connected_runner(server) as (client, _r): + await client.send_raw_request("tools/list", None) + [ctx] = _seen_ctx + assert ctx.session_id is None + assert ctx.headers is None diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py new file mode 100644 index 0000000000..43c2069a87 --- /dev/null +++ b/tests/server/test_server_context.py @@ -0,0 +1,156 @@ +"""Tests for the server-side `Context`. + +`Context` composes `BaseContext` (forwarding to a `DispatchContext`) with +`PeerMixin` (typed sample/elicit/roots/ping) plus `lifespan` and `connection`. +End-to-end tested over `DirectDispatcher`. +""" + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import CreateMessageResult, ListRootsRequest, ListRootsResult, SamplingMessage, TextContent + +from ..shared.conftest import direct_pair +from ..shared.test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@dataclass +class _Lifespan: + name: str + + +@pytest.mark.anyio +async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): + captured: list[Context[_Lifespan]] = [] + conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + captured.append(ctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): + # Now we have the server dispatcher; build the real Connection bound to it. + conn.__init__(server, has_standalone_channel=True) + with anyio.fail_after(5): + await client.send_raw_request("t", None) + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True + + +@pytest.mark.anyio +async def test_context_sample_round_trips_via_peer_mixin_on_base_context_outbound(): + crec = Recorder() + + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + crec.requests.append((method, params)) + return {"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"} + + results: list[CreateMessageResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append( + await ctx.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hi"))], + max_tokens=5, + ) + ) + return {} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=client_on_request, + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) + + +@pytest.mark.anyio +async def test_context_send_request_with_spec_type_infers_result_via_typed_mixin(): + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return {"roots": []} + + results: list[ListRootsResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append(await ctx.send_request(ListRootsRequest())) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_request=client_on_request) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert isinstance(results[0], ListRootsResult) + + +@pytest.mark.anyio +async def test_context_log_sends_request_scoped_message_notification(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("debug", "hello") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" + + +@pytest.mark.anyio +async def test_context_log_includes_logger_and_meta_when_supplied(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py new file mode 100644 index 0000000000..1222c05aba --- /dev/null +++ b/tests/shared/conftest.py @@ -0,0 +1,61 @@ +"""Shared fixtures for `Dispatcher` contract tests. + +The `pair_factory` fixture parametrizes contract tests over every `Dispatcher` +implementation, so the same behavioral assertions run against `DirectDispatcher` +(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams). +""" + +from collections.abc import Callable + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import Dispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext + +DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] +PairFactory = Callable[..., DispatcherTriple] + + +def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple: + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + + def close() -> None: + client.close() + server.close() + + return client, server, close + + +def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: + """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=can_send_request) + + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder) + + def close() -> None: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + return client, server, close + + +@pytest.fixture( + params=[ + pytest.param(direct_pair, id="direct"), + pytest.param(jsonrpc_pair, id="jsonrpc"), + ] +) +def pair_factory(request: pytest.FixtureRequest) -> PairFactory: + return request.param + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py new file mode 100644 index 0000000000..882f90bfab --- /dev/null +++ b/tests/shared/test_context.py @@ -0,0 +1,115 @@ +"""Tests for `BaseContext`. + +`BaseContext` is composition over a `DispatchContext` — it forwards +``transport``/``cancel_requested``/``send_raw_request``/``notify``/``progress`` +and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext + +from .conftest import direct_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_base_context_forwards_transport_and_cancel_requested(): + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + captured.append(bctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None + + +@pytest.mark.anyio +async def test_base_context_send_raw_request_and_notify_forward_to_dispatch_context(): + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + sample = await bctx.send_raw_request("sampling/createMessage", {"x": 1}) + await bctx.notify("notifications/message", {"level": "info"}) + return {"sample": sample} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=c_req, + client_on_notify=c_notify, + ) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + await crec.notified.wait() + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + + +@pytest.mark.anyio +async def test_base_context_report_progress_invokes_caller_on_progress(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await bctx.report_progress(0.5, total=1.0, message="halfway") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_base_context_satisfies_outbound_so_peer_mixin_works(): + """Wrapping a BaseContext in Peer proves it satisfies Outbound structurally.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await Peer(bctx).ping() + return {} + + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + async with running_pair( + direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert crec.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_base_context_meta_holds_supplied_request_params_meta(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx, meta={"progressToken": "abc"}) + assert bctx.meta is not None and bctx.meta.get("progressToken") == "abc" + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 0000000000..bdadd4cdae --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,258 @@ +"""Behavioral tests for the Dispatcher Protocol. + +The contract tests are parametrized over every `Dispatcher` implementation via +the `pair_factory` fixture (see ``conftest.py``); they must pass for both +`DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass +a concrete factory directly. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT + +from .conftest import PairFactory, direct_pair + + +class Recorder: + def __init__(self) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.contexts: list[DispatchContext[TransportContext]] = [] + self.notified = anyio.Event() + + +def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + recorder.requests.append((method, params)) + recorder.contexts.append(ctx) + return {"echoed": method, "params": dict(params or {})} + + async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + recorder.notifications.append((method, params)) + recorder.notified.set() + + return on_request, on_notify + + +@asynccontextmanager +async def running_pair( + factory: PairFactory, + *, + server_on_request: OnRequest | None = None, + server_on_notify: OnNotify | None = None, + client_on_request: OnRequest | None = None, + client_on_notify: OnNotify | None = None, + can_send_request: bool = True, +) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: + """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" + client, server, close = factory(can_send_request=can_send_request) + client_rec, server_rec = Recorder(), Recorder() + c_req, c_notify = echo_handlers(client_rec) + s_req, s_notify = echo_handlers(server_rec) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + tg.cancel_scope.cancel() + finally: + close() + + +@pytest.mark.anyio +async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/list", {"cursor": "abc"}) + assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} + assert srec.requests == [("tools/list", {"cursor": "abc"})] + + +@pytest.mark.anyio +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise MCPError(code=INVALID_PARAMS, message="bad cursor") + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" + + +@pytest.mark.anyio +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await anyio.sleep_forever() + raise NotImplementedError + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0}) + assert exc.value.error.code == REQUEST_TIMEOUT + + +@pytest.mark.anyio +async def test_notify_invokes_peer_on_notify(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/initialized", {"v": 1}) + await srec.notified.wait() + assert srec.notifications == [("notifications/initialized", {"v": 1})] + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_round_trips_to_calling_side(pair_factory: PairFactory): + """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + return {"sampled": sample} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] + assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", None) + + async with running_pair(pair_factory, server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/call", None) + assert exc.value.error.code == INVALID_REQUEST + + +@pytest.mark.anyio +async def test_ctx_notify_invokes_calling_side_on_notify(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/message", {"level": "info"}) + return {} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + await crec.notified.wait() + assert crec.notifications == [("notifications/message", {"level": "info"})] + + +@pytest.mark.anyio +async def test_ctx_progress_invokes_caller_on_progress_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): + """DirectDispatcher-specific: the original exception is chained via __cause__.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): + client, server = create_direct_dispatcher_pair() + s_req, s_notify = echo_handlers(Recorder()) + c_req, c_notify = echo_handlers(Recorder()) + + async def late_start(): + await anyio.sleep(0) + await server.run(s_req, s_notify) + + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, c_req, c_notify) + tg.start_soon(late_start) + with anyio.fail_after(5): + result = await client.send_raw_request("ping", None) + assert result == {"echoed": "ping", "params": {}} + client.close() + server.close() + + +@pytest.mark.anyio +async def test_direct_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): + d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + with pytest.raises(RuntimeError, match="no peer"): + await d.send_raw_request("ping", None) + with pytest.raises(RuntimeError, match="no peer"): + await d.notify("ping", None) + + +@pytest.mark.anyio +async def test_direct_close_makes_run_return(): + client, server = create_direct_dispatcher_pair() + on_request, on_notify = echo_handlers(Recorder()) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, on_request, on_notify) + tg.start_soon(client.run, on_request, on_notify) + client.close() + server.close() + + +if TYPE_CHECKING: + _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + _o: Outbound = _d diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py new file mode 100644 index 0000000000..5755b55d15 --- /dev/null +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -0,0 +1,614 @@ +"""JSON-RPC-specific Dispatcher tests. + +Behaviors with no `DirectDispatcher` analog: request-id correlation, the +exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. +The contract tests shared with `DirectDispatcher` live in +``test_dispatcher.py``. +""" + +import contextvars +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] + JSONRPCDispatcher, + _coerce_id, + _outbound_metadata, + _Pending, +) +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + Tool, +) + +from .conftest import jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): + release_first = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await release_first.wait() + return {"m": method} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + results: dict[str, dict[str, Any]] = {} + + async def call(method: str) -> None: + results[method] = await client.send_raw_request(method, None) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(call, "first") + await anyio.sleep(0) + tg.start_soon(call, "second") + await anyio.sleep(0) + # second resolves while first is still parked + assert "first" not in results + release_first.set() + assert results == {"first": {"m": "first"}, "second": {"m": "second"}} + + +@pytest.mark.anyio +async def test_handler_raising_exception_sends_internal_error_with_str_message(): + """Per design: INTERNAL_ERROR carries str(e), not a scrubbed message.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("kaboom") + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "kaboom" + assert exc.value.__cause__ is None # cause does not survive the wire + + +@pytest.mark.anyio +async def test_peer_cancel_interrupt_mode_sets_cancel_requested_and_sends_no_response(): + handler_started = anyio.Event() + handler_exited = anyio.Event() + seen_ctx: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen_ctx.append(ctx) + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call_then_record() -> None: + with pytest.raises(MCPError): # we'll cancel via tg below + await client.send_raw_request("slow", None) + + tg.start_soon(call_then_record) + await handler_started.wait() + # cancel just the handler (peer-cancel), not our caller + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + # Handler torn down, no response was written; caller is still parked. + # Cancel the caller's task to end the test. + tg.cancel_scope.cancel() + assert seen_ctx[0].cancel_requested.is_set() + + +@pytest.mark.anyio +async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): + handler_started = anyio.Event() + cancel_seen = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + cancel_seen.set() + return {"finished": True} + + def factory(*, can_send_request: bool = True): + client, server, close = jsonrpc_pair(can_send_request=can_send_request) + # Reach in to set signal mode on the server side. + assert isinstance(server, JSONRPCDispatcher) + server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] + return client, server, close + + result_box: list[dict[str, Any]] = [] + async with running_pair(factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result_box.append(await client.send_raw_request("slow", None)) + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await cancel_seen.wait() + assert result_box == [{"finished": True}] + + +@pytest.mark.anyio +async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + + tg.start_soon(caller) + await anyio.sleep(0) + # No server: simulate the peer dropping by closing the read side. + s2c_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_late_response_after_timeout_is_dropped_without_crashing(): + handler_started = anyio.Event() + proceed = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await proceed.wait() + return {"late": True} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await client.send_raw_request("slow", None, {"timeout": 0}) + # The server handler is still running; let it finish and write a + # response for an id the client has already discarded. + await handler_started.wait() + proceed.set() + # One more round-trip proves the dispatcher is still healthy. + assert await client.send_raw_request("ping", None) == {"late": True} + + +@pytest.mark.anyio +async def test_raise_handler_exceptions_true_propagates_out_of_run(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, transport_builder=builder, raise_handler_exceptions=True + ) + + async def boom(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("propagate me") + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + with pytest.raises(BaseException) as exc: + async with anyio.create_task_group() as tg: + await tg.start(server.run, boom, on_notify) + # Inject a request directly onto the server's read stream. + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) + ) + assert exc.group_contains(RuntimeError, match="propagate me") + # The error response was still written before re-raising. + sent = s2c_recv.receive_nowait() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCError) + assert sent.message.error.code == INTERNAL_ERROR + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): + """Server-to-client requests carry related_request_id for SHTTP routing.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + # Kick the server with an inbound request id=7. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ServerMessageMetadata) + assert outbound.metadata.related_request_id == 7 + # Reply so the handler completes cleanly. + await c2s_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) + ) + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.id == 7 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.25) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.25, None, None)] + + +@pytest.mark.anyio +async def test_handler_raising_validation_error_sends_invalid_params(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("t", None) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_send_raw_request_before_run_raises_runtimeerror(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + try: + with pytest.raises(RuntimeError, match="before run"): + await d.send_raw_request("ping", None) + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_transport_exception_in_read_stream_is_logged_and_dropped(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + # Dispatcher must remain healthy after the dropped exception. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/progress", {"progressToken": 999, "progress": 0.5}) + await srec.notified.wait() + assert srec.notifications == [("notifications/progress", {"progressToken": 999, "progress": 0.5})] + + +@pytest.mark.anyio +async def test_cancelled_notification_for_unknown_request_id_is_noop(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/cancelled", {"requestId": 999}) + # No effect; dispatcher remains healthy. + assert await client.send_raw_request("t", None) == {"echoed": "t", "params": {}} + assert srec.notifications == [] # cancelled is fully consumed, never teed + + +_probe: contextvars.ContextVar[str] = contextvars.ContextVar("probe", default="unset") + + +@pytest.mark.anyio +async def test_handler_inherits_sender_contextvars_via_spawn(): + """The handler task sees contextvars set by the task that wrote into the read stream.""" + raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) + read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) + write_send = ContextSendStream[SessionMessage | Exception](raw_send) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send) + + seen: list[str] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(_probe.get()) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + + async def sender() -> None: + _probe.set("from-sender") + await write_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + + tg.start_soon(sender) + with anyio.fail_after(5): + resp = await out_recv.receive() + assert isinstance(resp, SessionMessage) + tg.cancel_scope.cancel() + finally: + for s in (raw_send, raw_recv, out_send, out_recv): + s.close() + assert seen == ["from-sender"] + + +@pytest.mark.anyio +async def test_response_write_after_peer_drop_is_swallowed(): + """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + proceed = anyio.Event() + handlers_done = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await proceed.wait() + if method == "raise": + handlers_done.set() + raise MCPError(code=INTERNAL_ERROR, message="x") + return {"ok": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ok", params=None))) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="raise", params=None)) + ) + await anyio.sleep(0) + # Peer drops: close the receive end so the server's writes hit BrokenResourceError. + s2c_recv.close() + proceed.set() + with anyio.fail_after(5): + await handlers_done.wait() + # run() must still be healthy — close the read side to let it exit cleanly. + c2s_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): + """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + caller_done = anyio.Event() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + caller_scope = anyio.CancelScope() + + async def caller() -> None: + with caller_scope: + await client.send_raw_request("slow", None) + caller_done.set() + + tg.start_soon(caller) + # Deterministic proof the request write completed: pull it off the wire. + with anyio.fail_after(5): + sent = await c2s_recv.receive() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCRequest) + # Now safe: close the client's write end, then cancel the caller. The + # shielded `_cancel_outbound` write hits ClosedResourceError and is + # swallowed; cancellation propagates cleanly. + c2s_send.close() + caller_scope.cancel() + with anyio.fail_after(5): + await caller_done.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): + """White-box: a response for an id still in _pending but whose waiter has gone.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + recv.close() # waiter gone — send_nowait will raise BrokenResourceError + d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): + s.close() + + +def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): + """White-box: the buffer=1 invariant — WouldBlock means waiter already has an outcome.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + # Register a fake pending and pre-fill its single buffer slot. + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + send.send_nowait({"real": "result"}) + d._fan_out_closed() # pyright: ignore[reportPrivateUsage] + # The real result is still there; the close signal was dropped. + assert recv.receive_nowait() == {"real": "result"} + assert d._pending == {} # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send, recv): + s.close() + + +def test_outbound_metadata_with_resumption_token_returns_client_metadata(): + md = _outbound_metadata(None, {"resumption_token": "abc"}) + assert isinstance(md, ClientMessageMetadata) + assert md.resumption_token == "abc" + assert _outbound_metadata(None, None) is None + assert _outbound_metadata(None, {}) is None + + +@pytest.mark.anyio +async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A peer that echoes the request ID as a JSON string still resolves the waiter.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=str(rid), result={"ok": True})) + ) + + tg.start_soon(respond_stringly) + result = await client.send_raw_request("ping", None) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + seen: list[float] = [] + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_with_string_token_progress() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={"progressToken": str(rid), "progress": 0.5}, + ) + ) + ) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=rid, result={"ok": True})) + ) + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + seen.append(progress) + + tg.start_soon(respond_with_string_token_progress) + result = await client.send_raw_request("ping", None, {"on_progress": on_progress}) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [0.5] + + +def test_coerce_id_passes_through_non_numeric_string_and_int(): + assert _coerce_id("7") == 7 + assert _coerce_id("not-an-int") == "not-an-int" + assert _coerce_id(42) == 42 + + +@pytest.mark.anyio +async def test_jsonrpc_error_response_with_null_id_is_dropped(): + """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + await s2c_send.send( + SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) + ) + await anyio.sleep(0) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index ec7ff78cc1..a7df4c4294 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -10,9 +10,6 @@ pytestmark = pytest.mark.anyio -# Logfire warns about propagated trace context by default (distributed_tracing=None). -# This is expected here since we're testing cross-boundary context propagation. -@pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_client_and_server_spans(capfire: CaptureLogfire): """Verify that calling a tool produces client and server spans with correct attributes.""" server = MCPServer("test") diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py new file mode 100644 index 0000000000..0be4225818 --- /dev/null +++ b/tests/shared/test_peer.py @@ -0,0 +1,164 @@ +"""Tests for `PeerMixin` and `Peer`. + +Each PeerMixin method is tested by wrapping a `DirectDispatcher` in `Peer`, +calling the typed method, and asserting (a) the right method+params went out +and (b) the return value is the typed result model. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer, dump_params +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ElicitResult, + ListRootsResult, + SamplingMessage, + TextContent, + Tool, +) + +from .conftest import direct_pair +from .test_dispatcher import running_pair + +DCtx = DispatchContext[TransportContext] + + +class _Recorder: + def __init__(self, result: dict[str, Any]) -> None: + self.result = result + self.seen: list[tuple[str, Mapping[str, Any] | None]] = [] + + async def on_request(self, ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + self.seen.append((method, params)) + return self.result + + +@pytest.mark.anyio +async def test_peer_sample_sends_create_message_and_returns_typed_result(): + rec = _Recorder({"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], + max_tokens=10, + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" + + +@pytest.mark.anyio +async def test_peer_sample_with_tools_returns_with_tools_result(): + rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], + max_tokens=5, + tools=[Tool(name="t", input_schema={"type": "object"})], + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) + + +@pytest.mark.anyio +async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): + rec = _Recorder({"action": "accept", "content": {"name": "Max"}}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): + rec = _Recorder({"action": "accept"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): + rec = _Recorder({"roots": [{"uri": "file:///workspace"}]}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.list_roots() + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" + + +@pytest.mark.anyio +async def test_peer_list_roots_with_meta_sends_meta_in_params(): + rec = _Recorder({"roots": []}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + await peer.list_roots(meta={"traceId": "t1"}) + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} + + +def test_dump_params_merges_meta_over_model_meta(): + out = dump_params(None, None) + assert out is None + out = dump_params(None, {"k": 1}) + assert out == {"_meta": {"k": 1}} + + +@pytest.mark.anyio +async def test_peer_notify_forwards_to_wrapped_outbound(): + sent: list[tuple[str, Mapping[str, Any] | None]] = [] + + class _Out: + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: Any = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + sent.append((method, params)) + + await Peer(_Out()).notify("n", {"x": 1}) + assert sent == [("n", {"x": 1})] + + +@pytest.mark.anyio +async def test_peer_ping_sends_ping_and_returns_none(): + rec = _Recorder({}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.ping() + method, _ = rec.seen[0] + assert method == "ping" + assert result is None diff --git a/uv.lock b/uv.lock index 5b72e97fce..86ea2f5fb7 100644 --- a/uv.lock +++ b/uv.lock @@ -885,6 +885,7 @@ dev = [ { name = "inline-snapshot" }, { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -937,6 +938,7 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, + { name = "opentelemetry-sdk", specifier = ">=1.39.1" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.4.0" },