From 52ed34ab479f3323937cc7ee5ea1c3e18265d1bd Mon Sep 17 00:00:00 2001 From: Pranjal Bhatia <233476158+pranjalbhatia710@users.noreply.github.com> Date: Sun, 31 May 2026 11:42:51 +0400 Subject: [PATCH] fix: preserve stdio streams in server transport --- src/mcp/server/stdio.py | 27 +++++++++++++---- tests/server/test_stdio.py | 59 +++++++++++++++++++++++++------------- 2 files changed, 60 insertions(+), 26 deletions(-) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff6..559353e2b2 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -17,6 +17,7 @@ async def run_server(): ``` """ +import os import sys from contextlib import asynccontextmanager from io import TextIOWrapper @@ -38,10 +39,18 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. # standard process handles. Encoding of stdin/stdout as text streams on # python is platform-dependent (Windows is particularly problematic), so we # re-wrap the underlying binary stream to ensure UTF-8. + close_stdin = False + close_stdout = False if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")) + stdin_fd = os.dup(sys.stdin.fileno()) + stdin_buffer = os.fdopen(stdin_fd, "rb", closefd=True) + stdin = anyio.wrap_file(TextIOWrapper(stdin_buffer, encoding="utf-8", errors="replace")) + close_stdin = True if not stdout: - stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) + stdout_fd = os.dup(sys.stdout.fileno()) + stdout_buffer = os.fdopen(stdout_fd, "wb", closefd=True) + stdout = anyio.wrap_file(TextIOWrapper(stdout_buffer, encoding="utf-8")) + close_stdout = True read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) write_stream, write_stream_reader = create_context_streams[SessionMessage](0) @@ -71,7 +80,13 @@ async def stdout_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream + try: + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream + finally: + if close_stdin: + await stdin.aclose() + if close_stdout: + await stdout.aclose() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a993567..7ebb709099 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -1,5 +1,6 @@ import io import sys +import tempfile from io import TextIOWrapper import anyio @@ -64,7 +65,7 @@ async def test_stdio_server(): @pytest.mark.anyio -async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): +async def test_stdio_server_invalid_utf8(): """Non-UTF-8 bytes on stdin must not crash the server. Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and @@ -73,22 +74,40 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): """ # \xff\xfe are invalid UTF-8 start bytes. valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") - - # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that - # stdio_server()'s default path wraps it with errors='replace'. - monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) - monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) - - with anyio.fail_after(5): - async with stdio_server() as (read_stream, write_stream): - await write_stream.aclose() - async with read_stream: # pragma: no branch - # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream - first = await read_stream.receive() - assert isinstance(first, Exception) - - # Second line: valid message still comes through - second = await read_stream.receive() - assert isinstance(second, SessionMessage) - assert second.message == valid + raw_stdin = tempfile.TemporaryFile() + raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") + raw_stdin.seek(0) + raw_stdout = tempfile.TemporaryFile() + + # Replace sys.stdin/stdout with wrappers backed by real file descriptors so + # stdio_server()'s default path can duplicate them without closing the + # original process-level streams. + original_stdin = sys.stdin + original_stdout = sys.stdout + test_stdin = TextIOWrapper(raw_stdin, encoding="utf-8") + test_stdout = TextIOWrapper(raw_stdout, encoding="utf-8") + sys.stdin = test_stdin + sys.stdout = test_stdout + + try: + with anyio.fail_after(5): + async with stdio_server() as (read_stream, write_stream): + await write_stream.aclose() + async with read_stream: # pragma: no branch + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream + first = await read_stream.receive() + assert isinstance(first, Exception) + + # Second line: valid message still comes through + second = await read_stream.receive() + assert isinstance(second, SessionMessage) + assert second.message == valid + + assert not sys.stdin.closed + assert not sys.stdout.closed + sys.stdout.write("stdio still open") + finally: + sys.stdin = original_stdin + sys.stdout = original_stdout + test_stdin.close() + test_stdout.close()