Skip to content

Commit c8ae043

Browse files
committed
mcp(fix[_utils]): Add thread safety to server cache and fix invalidation
why: FastMCP runs sync tool functions in a thread pool via `anyio.to_thread.run_sync()`. The compound check-then-act pattern in `_get_server` (check `in`, access `[]`, possibly `del`) was not atomic, allowing concurrent tool calls to hit a `KeyError` when one thread deletes a dead server's cache entry between another thread's `in` check and `[]` access. Additionally, `_invalidate_server` did not resolve env vars (`LIBTMUX_SOCKET`, `LIBTMUX_SOCKET_PATH`), so calling `_invalidate_server(socket_name=None)` would search for `key[0] == None` but the cache key created by `_get_server` used the resolved env var value. what: - Add `threading.Lock` to protect `_server_cache` in both `_get_server` and `_invalidate_server` - Add env var resolution to `_invalidate_server` to match `_get_server`
1 parent b793fec commit c8ae043

1 file changed

Lines changed: 30 additions & 19 deletions

File tree

src/libtmux/mcp/_utils.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import logging
1212
import os
13+
import threading
1314
import typing as t
1415

1516
from libtmux import exc
@@ -24,6 +25,7 @@
2425
logger = logging.getLogger(__name__)
2526

2627
_server_cache: dict[tuple[str | None, str | None, str | None], Server] = {}
28+
_server_cache_lock = threading.Lock()
2729

2830

2931
def _get_server(
@@ -52,22 +54,23 @@ def _get_server(
5254
tmux_bin = os.environ.get("LIBTMUX_TMUX_BIN")
5355

5456
cache_key = (socket_name, socket_path, tmux_bin)
55-
if cache_key in _server_cache:
56-
cached = _server_cache[cache_key]
57-
if not cached.is_alive():
58-
del _server_cache[cache_key]
57+
with _server_cache_lock:
58+
if cache_key in _server_cache:
59+
cached = _server_cache[cache_key]
60+
if not cached.is_alive():
61+
del _server_cache[cache_key]
5962

60-
if cache_key not in _server_cache:
61-
kwargs: dict[str, t.Any] = {}
62-
if socket_name is not None:
63-
kwargs["socket_name"] = socket_name
64-
if socket_path is not None:
65-
kwargs["socket_path"] = socket_path
66-
if tmux_bin is not None:
67-
kwargs["tmux_bin"] = tmux_bin
68-
_server_cache[cache_key] = Server(**kwargs)
63+
if cache_key not in _server_cache:
64+
kwargs: dict[str, t.Any] = {}
65+
if socket_name is not None:
66+
kwargs["socket_name"] = socket_name
67+
if socket_path is not None:
68+
kwargs["socket_path"] = socket_path
69+
if tmux_bin is not None:
70+
kwargs["tmux_bin"] = tmux_bin
71+
_server_cache[cache_key] = Server(**kwargs)
6972

70-
return _server_cache[cache_key]
73+
return _server_cache[cache_key]
7174

7275

7376
def _invalidate_server(
@@ -83,11 +86,19 @@ def _invalidate_server(
8386
socket_path : str, optional
8487
tmux socket path used in the cache key.
8588
"""
86-
keys_to_remove = [
87-
key for key in _server_cache if key[0] == socket_name and key[1] == socket_path
88-
]
89-
for key in keys_to_remove:
90-
del _server_cache[key]
89+
if socket_name is None:
90+
socket_name = os.environ.get("LIBTMUX_SOCKET")
91+
if socket_path is None:
92+
socket_path = os.environ.get("LIBTMUX_SOCKET_PATH")
93+
94+
with _server_cache_lock:
95+
keys_to_remove = [
96+
key
97+
for key in _server_cache
98+
if key[0] == socket_name and key[1] == socket_path
99+
]
100+
for key in keys_to_remove:
101+
del _server_cache[key]
91102

92103

93104
def _resolve_session(

0 commit comments

Comments
 (0)