diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407ce..33d7fcd2e 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -767,10 +767,19 @@ async def terminate(self) -> None: Once terminated, all requests with this session ID will receive 404 Not Found. """ + if self._terminated: + return self._terminated = True logger.info(f"Terminating session: {self.mcp_session_id}") + # Close active SSE responses so ASGI response tasks can finish before + # the session manager cancels the owning task group. + sse_stream_writers = list(self._sse_stream_writers.values()) + self._sse_stream_writers.clear() + for writer in sse_stream_writers: + writer.close() + # We need a copy of the keys to avoid modification during iteration request_stream_keys = list(self._request_streams.keys()) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 81350a8f2..1f6d2a321 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -133,12 +133,23 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: yield # Let the application run finally: logger.info("StreamableHTTP session manager shutting down") - # Cancel task group to stop all spawned tasks - tg.cancel_scope.cancel() - self._task_group = None - # Clear any remaining server instances - self._server_instances.clear() - self._session_owners.clear() + try: + await self._terminate_active_sessions() + finally: + # Cancel task group to stop all spawned tasks + tg.cancel_scope.cancel() + self._task_group = None + # Clear any remaining server instances + self._server_instances.clear() + self._session_owners.clear() + + async def _terminate_active_sessions(self) -> None: + """Terminate tracked transports before cancelling their task group.""" + for transport in list(self._server_instances.values()): + try: + await transport.terminate() + except Exception: + logger.exception("Error terminating StreamableHTTP session during shutdown") async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index ba7554796..f110dc577 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -2,7 +2,8 @@ import json import logging -from typing import Any +from types import SimpleNamespace +from typing import Any, cast from unittest.mock import AsyncMock, patch import anyio @@ -64,6 +65,50 @@ async def try_run(): assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0]) +@pytest.mark.anyio +async def test_run_terminates_active_streaming_session_before_shutdown(): + """run() should close active SSE transports before task cancellation.""" + app = Server("test-shutdown-cleanup") + manager = StreamableHTTPSessionManager(app=app) + transport = StreamableHTTPServerTransport(mcp_session_id="session-id") + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1) + + try: + transport._sse_stream_writers["request-id"] = sse_stream_writer + + async with manager.run(): + manager._server_instances["session-id"] = transport + + assert transport.is_terminated + assert transport._sse_stream_writers == {} + assert manager._server_instances == {} + with pytest.raises(anyio.ClosedResourceError): + await sse_stream_writer.send({"data": "still-open"}) + finally: + await sse_stream_reader.aclose() + + +@pytest.mark.anyio +async def test_run_terminates_remaining_sessions_if_one_shutdown_fails(caplog: pytest.LogCaptureFixture): + """One failed transport shutdown should not skip later active sessions.""" + app = Server("test-shutdown-cleanup-error") + manager = StreamableHTTPSessionManager(app=app) + failing_terminate = AsyncMock(side_effect=RuntimeError("terminate failed")) + healthy_terminate = AsyncMock() + failing_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=failing_terminate)) + healthy_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=healthy_terminate)) + + with caplog.at_level(logging.ERROR): + async with manager.run(): + manager._server_instances["bad-session"] = failing_transport + manager._server_instances["healthy-session"] = healthy_transport + + failing_terminate.assert_awaited_once_with() + healthy_terminate.assert_awaited_once_with() + assert "Error terminating StreamableHTTP session during shutdown" in caplog.text + assert manager._server_instances == {} + + @pytest.mark.anyio async def test_handle_request_without_run_raises_error(): """Test that handle_request raises error if run() hasn't been called.""" @@ -271,6 +316,43 @@ async def mock_receive(): assert len(transport._request_streams) == 0, "Transport should have no active request streams" +@pytest.mark.anyio +async def test_transport_terminate_closes_sse_stream_writers(): + """terminate() should close active SSE writers so streaming responses can finish.""" + transport = StreamableHTTPServerTransport(mcp_session_id="test-session") + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1) + + try: + transport._sse_stream_writers["request-id"] = sse_stream_writer + + await transport.terminate() + + assert transport._sse_stream_writers == {} + with pytest.raises(anyio.ClosedResourceError): + await sse_stream_writer.send({"data": "still-open"}) + + await transport.terminate() + finally: + await sse_stream_reader.aclose() + + +@pytest.mark.anyio +async def test_transport_connect_cleans_request_streams_on_exit(): + """connect() should close registered request streams when the transport exits.""" + transport = StreamableHTTPServerTransport(mcp_session_id="test-session") + request_stream_writer, request_stream_reader = anyio.create_memory_object_stream[Any](1) + + transport._request_streams["request-id"] = (request_stream_writer, request_stream_reader) + + async with transport.connect(): + assert "request-id" in transport._request_streams + transport._terminated = True + + assert transport._request_streams == {} + with pytest.raises(anyio.ClosedResourceError): + await request_stream_writer.send(cast(Any, object())) + + @pytest.mark.anyio async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session IDs return HTTP 404 per MCP spec."""