Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook

async def __aenter__(self) -> Self: # pragma: no cover
async def __aenter__(self) -> Self:
# Enter the exit stack only if we created it ourselves
if self._owns_exit_stack:
await self._exit_stack.__aenter__()
Expand All @@ -158,7 +158,7 @@ async def __aexit__(
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> bool | None: # pragma: no cover
) -> bool | None:
"""Closes session exit stacks and main exit stack upon completion."""

# Only close the main exit stack if we created it
Expand Down Expand Up @@ -323,7 +323,7 @@ async def _establish_session(
await self._exit_stack.enter_async_context(session_stack)

return result.server_info, session
except Exception: # pragma: no cover
except Exception:
# If anything during this setup fails, ensure the session-specific
# stack is closed.
await session_stack.aclose()
Expand Down
15 changes: 12 additions & 3 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,17 +467,26 @@ async def _handle_message(session_message: SessionMessage) -> None:
read_stream_writer=read_stream_writer,
)

async def handle_request_async():
async def send_message() -> None:
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)

async def handle_request_async(request: JSONRPCRequest) -> None:
try:
await send_message()
except httpx.TransportError as exc:
logger.debug("Error handling request", exc_info=True)
error_data = ErrorData(code=INTERNAL_ERROR, message=f"Transport error: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request.id, error=error_data))
await ctx.read_stream_writer.send(error_msg)

# If this is a request, start a new task to handle it
if isinstance(message, JSONRPCRequest):
tg.start_soon(handle_request_async)
tg.start_soon(handle_request_async, message)
else:
await handle_request_async()
await send_message()

async for session_message in write_stream_reader:
sender_ctx = write_stream_reader.last_context
Expand Down
40 changes: 40 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ def test_client_session_group_component_properties():
assert mcp_session_group.tools == {"my_tool": mock_tool}


@pytest.mark.anyio
async def test_client_session_group_context_manager_closes_session_stacks_with_external_stack():
class SessionStack(contextlib.AsyncExitStack):
def __init__(self) -> None:
super().__init__()
self.closed = False

async def aclose(self) -> None:
self.closed = True
await super().aclose()

session_stack = SessionStack()
group = ClientSessionGroup(exit_stack=contextlib.AsyncExitStack())
group._session_exit_stacks[mock.Mock(spec=mcp.ClientSession)] = session_stack

async with group as entered:
assert entered is group

assert session_stack.closed


@pytest.mark.anyio
async def test_client_session_group_call_tool():
# --- Mock Dependencies ---
Expand Down Expand Up @@ -278,6 +299,25 @@ async def test_client_session_group_disconnect_non_existent_server():
await group.disconnect_from_server(session)


@pytest.mark.anyio
async def test_client_session_group_streamable_http_connection_error_surfaces() -> None:
async def fail_request(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("offline", request=request)

http_client = httpx.AsyncClient(transport=httpx.MockTransport(fail_request))

with mock.patch("mcp.client.session_group.create_mcp_http_client", return_value=http_client):
async with ClientSessionGroup() as group:
with pytest.raises(MCPError) as excinfo: # pragma: no branch
await group.connect_to_server(
StreamableHttpParameters(url="http://example.test/mcp"),
ClientSessionParameters(read_timeout_seconds=2),
)

assert excinfo.value.error.code == types.INTERNAL_ERROR
assert excinfo.value.error.message == "Transport error: offline"


# TODO(Marcelo): This is horrible. We should drop this test.
@pytest.mark.anyio
@pytest.mark.parametrize(
Expand Down
Loading