Skip to content

Commit 52ed34a

Browse files
fix: preserve stdio streams in server transport
1 parent 616476f commit 52ed34a

2 files changed

Lines changed: 60 additions & 26 deletions

File tree

src/mcp/server/stdio.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ async def run_server():
1717
```
1818
"""
1919

20+
import os
2021
import sys
2122
from contextlib import asynccontextmanager
2223
from io import TextIOWrapper
@@ -38,10 +39,18 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.
3839
# standard process handles. Encoding of stdin/stdout as text streams on
3940
# python is platform-dependent (Windows is particularly problematic), so we
4041
# re-wrap the underlying binary stream to ensure UTF-8.
42+
close_stdin = False
43+
close_stdout = False
4144
if not stdin:
42-
stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace"))
45+
stdin_fd = os.dup(sys.stdin.fileno())
46+
stdin_buffer = os.fdopen(stdin_fd, "rb", closefd=True)
47+
stdin = anyio.wrap_file(TextIOWrapper(stdin_buffer, encoding="utf-8", errors="replace"))
48+
close_stdin = True
4349
if not stdout:
44-
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
50+
stdout_fd = os.dup(sys.stdout.fileno())
51+
stdout_buffer = os.fdopen(stdout_fd, "wb", closefd=True)
52+
stdout = anyio.wrap_file(TextIOWrapper(stdout_buffer, encoding="utf-8"))
53+
close_stdout = True
4554

4655
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
4756
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
@@ -71,7 +80,13 @@ async def stdout_writer():
7180
except anyio.ClosedResourceError: # pragma: no cover
7281
await anyio.lowlevel.checkpoint()
7382

74-
async with anyio.create_task_group() as tg:
75-
tg.start_soon(stdin_reader)
76-
tg.start_soon(stdout_writer)
77-
yield read_stream, write_stream
83+
try:
84+
async with anyio.create_task_group() as tg:
85+
tg.start_soon(stdin_reader)
86+
tg.start_soon(stdout_writer)
87+
yield read_stream, write_stream
88+
finally:
89+
if close_stdin:
90+
await stdin.aclose()
91+
if close_stdout:
92+
await stdout.aclose()

tests/server/test_stdio.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import sys
3+
import tempfile
34
from io import TextIOWrapper
45

56
import anyio
@@ -64,7 +65,7 @@ async def test_stdio_server():
6465

6566

6667
@pytest.mark.anyio
67-
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
68+
async def test_stdio_server_invalid_utf8():
6869
"""Non-UTF-8 bytes on stdin must not crash the server.
6970
7071
Invalid bytes are replaced with U+FFFD, which then fails JSON parsing and
@@ -73,22 +74,40 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
7374
"""
7475
# \xff\xfe are invalid UTF-8 start bytes.
7576
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
76-
raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
77-
78-
# Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that
79-
# stdio_server()'s default path wraps it with errors='replace'.
80-
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
81-
monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))
82-
83-
with anyio.fail_after(5):
84-
async with stdio_server() as (read_stream, write_stream):
85-
await write_stream.aclose()
86-
async with read_stream: # pragma: no branch
87-
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
88-
first = await read_stream.receive()
89-
assert isinstance(first, Exception)
90-
91-
# Second line: valid message still comes through
92-
second = await read_stream.receive()
93-
assert isinstance(second, SessionMessage)
94-
assert second.message == valid
77+
raw_stdin = tempfile.TemporaryFile()
78+
raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
79+
raw_stdin.seek(0)
80+
raw_stdout = tempfile.TemporaryFile()
81+
82+
# Replace sys.stdin/stdout with wrappers backed by real file descriptors so
83+
# stdio_server()'s default path can duplicate them without closing the
84+
# original process-level streams.
85+
original_stdin = sys.stdin
86+
original_stdout = sys.stdout
87+
test_stdin = TextIOWrapper(raw_stdin, encoding="utf-8")
88+
test_stdout = TextIOWrapper(raw_stdout, encoding="utf-8")
89+
sys.stdin = test_stdin
90+
sys.stdout = test_stdout
91+
92+
try:
93+
with anyio.fail_after(5):
94+
async with stdio_server() as (read_stream, write_stream):
95+
await write_stream.aclose()
96+
async with read_stream: # pragma: no branch
97+
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
98+
first = await read_stream.receive()
99+
assert isinstance(first, Exception)
100+
101+
# Second line: valid message still comes through
102+
second = await read_stream.receive()
103+
assert isinstance(second, SessionMessage)
104+
assert second.message == valid
105+
106+
assert not sys.stdin.closed
107+
assert not sys.stdout.closed
108+
sys.stdout.write("stdio still open")
109+
finally:
110+
sys.stdin = original_stdin
111+
sys.stdout = original_stdout
112+
test_stdin.close()
113+
test_stdout.close()

0 commit comments

Comments
 (0)