Skip to content

Commit b6ecd4d

Browse files
authored
Keep MCP alive during entire Agent lifetime (#68)
This change reuses the MCP connections so that they stay alive for the entire Agent lifetime. The mcp library is safe for concurrent tool calls; a stress test was added to prove this. See also the official LangChain MCP adapters, which reuse the MCP session across tool calls: https://github.com/langchain-ai/langchain-mcp-adapters/blob/main/langchain_mcp_adapters/tools.py#L273
1 parent 3e3508a commit b6ecd4d

9 files changed

Lines changed: 393 additions & 205 deletions

File tree

splunklib/ai/agent.py

Lines changed: 101 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
1516
from logging import Logger
1617
import os
17-
from collections.abc import Sequence
18+
from collections.abc import AsyncGenerator, Sequence
1819
from typing import Self, final, override
1920

2021
from pydantic import BaseModel
@@ -26,7 +27,14 @@
2627
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2728
from splunklib.ai.model import PredefinedModel
2829
from splunklib.ai.tool_filtering import ToolFilters, filter_tools
29-
from splunklib.ai.tools import Tool, build_local_tools_path, load_mcp_tools, locate_app
30+
from splunklib.ai.tools import (
31+
Tool,
32+
build_local_tools_path,
33+
connect_local_mcp,
34+
connect_remote_mcp,
35+
load_mcp_tools,
36+
locate_app,
37+
)
3038
from splunklib.client import Service
3139

3240
# For testing purposes, overrides the automatically inferred tools.py path.
@@ -101,14 +109,15 @@ class Agent(BaseAgent[OutputT]):
101109
is appropriate for a given task. Ignored for top-level agents.
102110
103111
logger:
104-
Optional logger instance used for tracing and debugging the agents execution.
112+
Optional logger instance used for tracing and debugging the agent's execution.
105113
Additionally logs from the local tools are forwarded to this logger.
106114
"""
107115

108116
_impl: AgentImpl[OutputT] | None
109117
_use_mcp_tools: bool
110118
_service: Service
111119
_tool_filters: ToolFilters | None
120+
_agent_context_manager: AbstractAsyncContextManager[Self] | None = None
112121

113122
def __init__(
114123
self,
@@ -119,9 +128,9 @@ def __init__(
119128
tool_filters: ToolFilters | None = None,
120129
agents: Sequence[BaseAgent[BaseModel | None]] | None = None,
121130
output_schema: type[OutputT] | None = None,
122-
input_schema: type[BaseModel] | None = None, # Only used by Subgents
131+
input_schema: type[BaseModel] | None = None, # Only used by Subagents
123132
hooks: Sequence[AgentHook] | None = None,
124-
name: str = "", # Only used by Subgents
133+
name: str = "", # Only used by Subagents
125134
description: str = "", # Only used by Subagents
126135
logger: Logger | None = None,
127136
) -> None:
@@ -142,36 +151,94 @@ def __init__(
142151
self._service = service
143152
self._impl = None
144153

145-
async def __aenter__(self) -> Self:
146-
if self._impl:
147-
raise AssertionError("Agent is already in `async with` context")
148-
149-
if self.name:
150-
self.logger.debug(f"Creating agent {self.name}; trace_id={self.trace_id}")
151-
else:
152-
self.logger.debug(f"Creating agent; trace_id={self.trace_id}")
153-
154-
if self._use_mcp_tools:
155-
self._tools = await _load_tools_from_mcp(
156-
self._service,
157-
self._tool_filters,
158-
self.trace_id,
159-
self.logger,
154+
@asynccontextmanager
155+
async def _start_agent(self) -> AsyncGenerator[Self]:
156+
async with AsyncExitStack() as stack:
157+
assert self._impl is None, (
158+
"internal error: _impl was not set to None after agent invocation"
160159
)
161160

162-
backend = get_backend()
163-
self._impl = await backend.create_agent(self)
161+
if self.name:
162+
self.logger.debug(
163+
f"Creating agent {self.name}; trace_id={self.trace_id}"
164+
)
165+
else:
166+
self.logger.debug(f"Creating agent; trace_id={self.trace_id}")
167+
168+
if self._use_mcp_tools:
169+
tools: list[Tool] = []
170+
171+
self.logger.debug("Local tool registry detected")
172+
local_tools_path, app_id = _local_tools_path()
173+
if local_tools_path:
174+
local_session = await stack.enter_async_context(
175+
connect_local_mcp(local_tools_path, self.logger)
176+
)
177+
self.logger.debug("Loading local tools")
178+
local_tools = await load_mcp_tools(
179+
local_session, "local", app_id, self.trace_id, self._service
180+
)
181+
self.logger.debug(f"Local tools loaded; {local_tools=}")
182+
tools.extend(local_tools)
183+
184+
self.logger.debug("Probing MCP Server App availability")
185+
remote_session = await stack.enter_async_context(
186+
connect_remote_mcp(
187+
self._service,
188+
app_id,
189+
self.trace_id,
190+
)
191+
)
192+
if remote_session:
193+
self.logger.debug("Loading remote tools - MCP Server available")
194+
remote_tools = await load_mcp_tools(
195+
remote_session,
196+
"remote",
197+
app_id,
198+
self.trace_id,
199+
self._service,
200+
)
201+
self.logger.debug(f"Remote tools loaded; {remote_tools=}")
202+
tools.extend(remote_tools)
203+
204+
if self._tool_filters:
205+
tools = filter_tools(tools, self._tool_filters)
206+
207+
self.logger.debug(
208+
f"Tools loaded & filtered successfully; tools_after_filtering={[tool.name for tool in tools]}"
209+
)
210+
211+
self._tools = tools
212+
213+
backend = get_backend()
214+
self._impl = await backend.create_agent(self)
215+
216+
if self.name:
217+
self.logger.debug(
218+
f"Agent {self.name} created; trace_id={self.trace_id}"
219+
)
220+
else:
221+
self.logger.debug(f"Agent created; trace_id={self.trace_id}")
222+
223+
yield self
224+
225+
self._impl = None
164226

165-
if self.name:
166-
self.logger.debug(f"Agent {self.name} created; trace_id={self.trace_id}")
167-
else:
168-
self.logger.debug(f"Agent created; trace_id={self.trace_id}")
169-
170-
return self
171-
172-
async def __aexit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
173-
self._impl = None # Make sure invoke fails if called after exit.
174-
return None
227+
async def __aenter__(self) -> Self:
228+
if self._agent_context_manager:
229+
raise AssertionError("Agent is already in `async with` context")
230+
self._agent_context_manager = self._start_agent()
231+
return await self._agent_context_manager.__aenter__()
232+
233+
async def __aexit__(
234+
self, exc_type: ..., exc_value: ..., traceback: ...
235+
) -> bool | None:
236+
assert self._agent_context_manager is not None
237+
return await self._agent_context_manager.__aexit__(
238+
exc_type,
239+
exc_value,
240+
traceback,
241+
)
175242

176243
@override
177244
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]:
@@ -181,12 +248,7 @@ async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]:
181248
return await self._impl.invoke(messages)
182249

183250

184-
async def _load_tools_from_mcp(
185-
service: Service,
186-
filters: ToolFilters | None,
187-
trace_id: str,
188-
logger: Logger,
189-
) -> list[Tool]:
251+
def _local_tools_path() -> tuple[str | None, str]:
190252
local_tools_path = _testing_local_tools_path
191253
app_id = _testing_app_id
192254

@@ -201,14 +263,4 @@ async def _load_tools_from_mcp(
201263
if not os.path.exists(local_tools_path):
202264
local_tools_path = None
203265

204-
mcp_tools = await load_mcp_tools(
205-
service, local_tools_path, app_id, trace_id, logger
206-
)
207-
if filters:
208-
return filter_tools(mcp_tools, filters)
209-
210-
logger.debug(
211-
f"Tools loaded & filtered successfully; tools_after_filtering={[tool.name for tool in mcp_tools]}"
212-
)
213-
214-
return mcp_tools
266+
return local_tools_path, app_id

splunklib/ai/engines/langchain.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from collections.abc import Sequence
1919
from dataclasses import asdict, dataclass
2020
from functools import partial
21-
from time import monotonic
2221
from typing import Any, Awaitable, Callable, cast, override
2322

2423
from langchain.agents import create_agent

splunklib/ai/registry.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,29 @@ def _min_logging_level(level: types.LoggingLevel) -> int:
8585
return logging.CRITICAL
8686

8787

88+
@dataclass
89+
class LogData:
90+
tool_name: str
91+
message: str
92+
93+
8894
class _MCPLoggingHandler(logging.Handler):
8995
_group: asyncio.TaskGroup
9096
_session: ServerSession
9197
_request_id: types.RequestId
98+
_tool_name: str
9299

93100
def __init__(
94101
self,
95102
group: asyncio.TaskGroup,
96103
session: ServerSession,
97104
request_id: types.RequestId,
105+
tool_name: str,
98106
) -> None:
99107
self._group = group
100108
self._session = session
101109
self._request_id = request_id
110+
self._tool_name = tool_name
102111
super().__init__()
103112

104113
@override
@@ -108,7 +117,7 @@ def emit(self, record: logging.LogRecord) -> None:
108117
async def send_log() -> None:
109118
await self._session.send_log_message(
110119
level=mcp_level,
111-
data=record.msg,
120+
data=asdict(LogData(tool_name=self._tool_name, message=record.msg)),
112121
logger="",
113122
related_request_id=self._request_id,
114123
)
@@ -265,6 +274,7 @@ async def _call_tool(
265274
task_group,
266275
req_ctx.session,
267276
req_ctx.request_id,
277+
name,
268278
)
269279

270280
# Create a logger that forwards all logs to the client over MCP.

0 commit comments

Comments
 (0)