diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae..357865133 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -456,26 +456,6 @@ async def _handle_session_message(message: SessionMessage) -> None: pass self._response_streams.clear() - def _normalize_request_id(self, response_id: RequestId) -> RequestId: - """Normalize a response ID to match how request IDs are stored. - - Since the client always sends integer IDs, we normalize string IDs - to integers when possible. This matches the TypeScript SDK approach: - https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861 - - Args: - response_id: The response ID from the incoming message. - - Returns: - The normalized ID (int if possible, otherwise original value). - """ - if isinstance(response_id, str): - try: - return int(response_id) - except ValueError: - logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") - return response_id - async def _handle_response(self, message: SessionMessage) -> None: """Handle an incoming response or error message. @@ -495,8 +475,7 @@ async def _handle_response(self, message: SessionMessage) -> None: logging.warning(f"Received error with null ID: {error.message}") await self._handle_incoming(MCPError(error.code, error.message, error.data)) return - # Normalize response ID to handle type mismatches (e.g., "0" vs 0) - response_id = self._normalize_request_id(message.message.id) + response_id = message.message.id # First, check response routers (e.g., TaskResultHandler) if isinstance(message.message, JSONRPCError): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5..a1582d0a2 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -99,44 +99,47 @@ async def make_request(client: Client): @pytest.mark.anyio -async def test_response_id_type_mismatch_string_to_int(): - """Test that responses with string IDs are correctly matched to requests sent with - integer IDs. +async def test_response_id_type_mismatch_string_to_int_rejected(): + """Verify that a response with a string ID does not match a request sent with + an integer ID. - This handles the case where a server returns "id": "0" (string) but the client - sent "id": 0 (integer). Without ID type normalization, this would cause a timeout. + Per JSON-RPC 2.0, the response ID "MUST be the same as the value of the id + member in the Request Object". Since Python treats 0 != "0", a server that + echoes back "0" instead of 0 is non-compliant and the request should time out. """ - ev_response_received = anyio.Event() - result_holder: list[types.EmptyResult] = [] + ev_timeout = anyio.Event() async with create_client_server_memory_streams() as (client_streams, server_streams): client_read, client_write = client_streams server_read, server_write = server_streams async def mock_server(): - """Receive a request and respond with a string ID instead of integer.""" message = await server_read.receive() assert isinstance(message, SessionMessage) root = message.message assert isinstance(root, JSONRPCRequest) - # Get the original request ID (which is an integer) request_id = root.id - assert isinstance(request_id, int), f"Expected int, got {type(request_id)}" + assert isinstance(request_id, int) - # Respond with the ID as a string (simulating a buggy server) + # Respond with the ID as a string (non-compliant server) response = JSONRPCResponse( jsonrpc="2.0", - id=str(request_id), # Convert to string to simulate mismatch + id=str(request_id), result={}, ) await server_write.send(SessionMessage(message=response)) async def make_request(client_session: ClientSession): - nonlocal result_holder - # Send a ping request (uses integer ID internally) - result = await client_session.send_ping() - result_holder.append(result) - ev_response_received.set() + try: + await client_session.send_request( + types.PingRequest(), + types.EmptyResult, + request_read_timeout_seconds=0.5, + ) + pytest.fail("Expected timeout") # pragma: no cover + except MCPError as e: + assert "Timed out" in str(e) + ev_timeout.set() async with ( anyio.create_task_group() as tg, @@ -146,29 +149,23 @@ async def make_request(client_session: ClientSession): tg.start_soon(make_request, client_session) with anyio.fail_after(2): # pragma: no branch - await ev_response_received.wait() - - assert len(result_holder) == 1 - assert isinstance(result_holder[0], EmptyResult) + await ev_timeout.wait() @pytest.mark.anyio -async def test_error_response_id_type_mismatch_string_to_int(): - """Test that error responses with string IDs are correctly matched to requests - sent with integer IDs. +async def test_error_response_id_type_mismatch_string_to_int_rejected(): + """Verify that an error response with a string ID does not match a request + sent with an integer ID. - This handles the case where a server returns an error with "id": "0" (string) - but the client sent "id": 0 (integer). + The JSON-RPC spec requires exact ID matching including type. """ - ev_error_received = anyio.Event() - error_holder: list[MCPError | Exception] = [] + ev_timeout = anyio.Event() async with create_client_server_memory_streams() as (client_streams, server_streams): client_read, client_write = client_streams server_read, server_write = server_streams async def mock_server(): - """Receive a request and respond with an error using a string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) root = message.message @@ -176,22 +173,25 @@ async def mock_server(): request_id = root.id assert isinstance(request_id, int) - # Respond with an error, using the ID as a string + # Respond with an error using the ID as a string (non-compliant) error_response = JSONRPCError( jsonrpc="2.0", - id=str(request_id), # Convert to string to simulate mismatch + id=str(request_id), error=ErrorData(code=-32600, message="Test error"), ) await server_write.send(SessionMessage(message=error_response)) async def make_request(client_session: ClientSession): - nonlocal error_holder try: - await client_session.send_ping() - pytest.fail("Expected MCPError to be raised") # pragma: no cover + await client_session.send_request( + types.PingRequest(), + types.EmptyResult, + request_read_timeout_seconds=0.5, + ) + pytest.fail("Expected timeout") # pragma: no cover except MCPError as e: - error_holder.append(e) - ev_error_received.set() + assert "Timed out" in str(e) + ev_timeout.set() async with ( anyio.create_task_group() as tg, @@ -201,16 +201,13 @@ async def make_request(client_session: ClientSession): tg.start_soon(make_request, client_session) with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - - assert len(error_holder) == 1 - assert "Test error" in str(error_holder[0]) + await ev_timeout.wait() @pytest.mark.anyio async def test_response_id_non_numeric_string_no_match(): - """Test that responses with non-numeric string IDs don't incorrectly match - integer request IDs. + """Test that responses with non-numeric string IDs don't match integer + request IDs. If a server returns "id": "abc" (non-numeric string), it should not match a request sent with "id": 0 (integer).