Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_REQUEST,
PARSE_ERROR,
Expand Down Expand Up @@ -366,10 +367,16 @@ async def _handle_sse_response(
except Exception:
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover

# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
# Stream ended without a terminal response/error. If the server provided an event id,
# try resuming; otherwise fail the request instead of hanging forever.
if last_event_id is None:
error_data = ErrorData(code=CONNECTION_CLOSED, message="SSE stream disconnected before response completed")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await ctx.read_stream_writer.send(error_msg)
return

logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)

async def _handle_reconnection(
self,
Expand All @@ -380,7 +387,16 @@ async def _handle_reconnection(
) -> None:
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
if attempt >= MAX_RECONNECTION_ATTEMPTS:
assert isinstance(ctx.session_message.message, JSONRPCRequest)
original_request_id = ctx.session_message.message.id
error_data = ErrorData(
code=CONNECTION_CLOSED,
message="SSE stream disconnected and could not be resumed",
data={"last_event_id": last_event_id},
)
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await ctx.read_stream_writer.send(error_msg)
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
return

Expand Down Expand Up @@ -424,7 +440,7 @@ async def _handle_reconnection(
# Stream ended again without response - reconnect again (reset attempt counter)
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
except Exception as e: # pragma: no cover
except Exception as e:
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
Expand Down
33 changes: 33 additions & 0 deletions tests/client/test_streamable_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import httpx
import pytest

from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
from mcp.shared._context_streams import create_context_streams
from mcp.shared.message import SessionMessage
from mcp.types import CONNECTION_CLOSED, JSONRPCError, JSONRPCRequest

pytestmark = pytest.mark.anyio


async def test_sse_response_disconnect_before_any_event_id_fails_request() -> None:
transport = StreamableHTTPTransport("http://example.com/mcp")
async with httpx.AsyncClient() as client:
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](1)
request = JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": "noop", "arguments": {}})
ctx = RequestContext(
client=client,
session_id=None,
session_message=SessionMessage(request),
metadata=None,
read_stream_writer=read_stream_writer,
)
response = httpx.Response(200, headers={"content-type": "text/event-stream"}, content=b"")

async with read_stream_writer, read_stream:
await transport._handle_sse_response(response, ctx)
message = await read_stream.receive()

assert isinstance(message, SessionMessage)
assert isinstance(message.message, JSONRPCError)
assert message.message.id == 1
assert message.message.error.code == CONNECTION_CLOSED
76 changes: 76 additions & 0 deletions tests/interaction/transports/test_hosting_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,82 @@ async def call() -> None:
assert received == snapshot(["before close", "after close"])


@requirement("hosting:resume:close-stream")
@requirement("transport:streamable-http:resumability")
@requirement("client-transport:http:reconnect-post-priming")
@requirement("client-transport:http:reconnect-retry-value")
async def test_a_call_whose_stream_closes_and_cannot_be_resumed_fails_instead_of_hanging() -> None:
"""If a resumable response stream disconnects and the server session is gone, the client fails
the request instead of hanging forever.

The server closes the call's SSE stream after emitting one related notification. The test then
deletes the active server-side session to force the client's reconnect GET to return 404.
Without a terminal response/error on the read stream, ClientSession.send_request waits forever
(read timeout defaults to None). The transport must surface a request-scoped error when it
gives up reconnecting.
"""
reconnect_attempted = anyio.Event()
allow_exit = anyio.Event()
done = anyio.Event()
raised: list[BaseException] = []
manager_ref = None
deleted_session = False

mcp = MCPServer("resumable")

@mcp.tool()
async def interrupt(ctx: Context) -> str:
await ctx.info("before close")
await ctx.close_sse_stream()
await allow_exit.wait()
return "unreachable"

async def record_request(request: httpx.Request) -> None:
nonlocal deleted_session
if request.method != "GET":
return
if request.headers.get("last-event-id") is None:
return
reconnect_attempted.set()
if deleted_session or manager_ref is None:
return
session_ids = list(manager_ref._server_instances.keys())
if session_ids: # pragma: no branch
del manager_ref._server_instances[session_ids[0]]
deleted_session = True

async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0, on_request=record_request) as (
http,
manager,
):
manager_ref = manager
with anyio.fail_after(5): # pragma: no branch
async with ( # pragma: no branch
streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r, w),
ClientSession(r, w) as session,
anyio.create_task_group() as tg,
):
await session.initialize()

async def call() -> None:
try:
await session.call_tool("interrupt", {})
except BaseException as exc:
raised.append(exc)
finally:
done.set()

tg.start_soon(call)
await reconnect_attempted.wait()
await done.wait()
allow_exit.set()
tg.cancel_scope.cancel()

assert len(raised) == 1
assert isinstance(raised[0], Exception)
assert "disconnected" in str(raised[0]).lower()


@requirement("client-transport:http:resume-stream-api")
async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None:
"""A resumption token captured via on_resumption_token_update on one connection lets a fresh
Expand Down
Loading