Skip to content

Commit 9304de2

Browse files
authored
Propagate trace_id & app_id to MCP (#49)
This change makes the Agent generate a trace_id during startup and propagates it with app_id to the MCP Server App in `_meta` fields of MCP requests and in the HTTP headers.
1 parent 8e222bd commit 9304de2

5 files changed

Lines changed: 121 additions & 30 deletions

File tree

splunklib/ai/agent.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,12 @@
2626
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2727
from splunklib.ai.model import PredefinedModel
2828
from splunklib.ai.tool_filtering import ToolFilters, filter_tools
29-
from splunklib.ai.tools import (
30-
Tool,
31-
load_mcp_tools,
32-
locate_tools_path_by_sdk_location,
33-
)
29+
from splunklib.ai.tools import Tool, build_local_tools_path, load_mcp_tools, locate_app
3430
from splunklib.client import Service
3531

3632
# For testing purposes, overrides the automatically inferred tools.py path.
3733
_testing_local_tools_path: str | None = None
34+
_testing_app_id: str | None = None
3835

3936

4037
@final
@@ -149,7 +146,9 @@ async def __aenter__(self) -> Self:
149146
raise AssertionError("Agent is already in `async with` context")
150147

151148
if self._use_mcp_tools:
152-
self._tools = await _load_tools_from_mcp(self._service, self._tool_filters)
149+
self._tools = await _load_tools_from_mcp(
150+
self._service, self._tool_filters, self.trace_id
151+
)
153152

154153
backend = get_backend()
155154
self._impl = await backend.create_agent(self)
@@ -169,16 +168,25 @@ async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]:
169168

170169

171170
async def _load_tools_from_mcp(
172-
service: Service, filters: ToolFilters | None
171+
service: Service,
172+
filters: ToolFilters | None,
173+
trace_id: str,
173174
) -> list[Tool]:
174175
local_tools_path = _testing_local_tools_path
176+
app_id = _testing_app_id
177+
175178
if local_tools_path is None:
176-
local_tools_path = locate_tools_path_by_sdk_location()
179+
app_id, app_dir = locate_app()
180+
local_tools_path = build_local_tools_path(app_dir)
181+
182+
assert app_id is not None, (
183+
"_load_tools_from_mcp was mocked, but _testing_app_id not"
184+
)
177185

178186
if not os.path.exists(local_tools_path):
179187
local_tools_path = None
180188

181-
mcp_tools = await load_mcp_tools(service, local_tools_path)
189+
mcp_tools = await load_mcp_tools(service, local_tools_path, app_id, trace_id)
182190
if filters:
183191
return filter_tools(mcp_tools, filters)
184192

splunklib/ai/base_agent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from abc import ABC, abstractmethod
1717
from collections.abc import Sequence
18+
import secrets
1819
from typing import Generic
1920

2021
from pydantic import BaseModel
@@ -35,6 +36,7 @@ class BaseAgent(Generic[OutputT], ABC):
3536
_input_schema: type[BaseModel] | None = None
3637
_output_schema: type[OutputT] | None = None
3738
_hooks: Sequence[AgentHook] | None = None
39+
_trace_id: str
3840

3941
def __init__(
4042
self,
@@ -57,6 +59,7 @@ def __init__(
5759
self._input_schema = input_schema
5860
self._output_schema = output_schema
5961
self._hooks = tuple(hooks) if hooks else ()
62+
self._trace_id = secrets.token_hex(16) # 32 Hex characters
6063

6164
@abstractmethod
6265
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]: ...
@@ -96,3 +99,7 @@ def output_schema(self) -> type[OutputT] | None:
9699
@property
97100
def hooks(self) -> Sequence[AgentHook] | None:
98101
return self._hooks
102+
103+
@property
104+
def trace_id(self) -> str:
105+
return self._trace_id

splunklib/ai/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class ToolContext:
3535

3636
_management_url: str | None = None
3737
_management_token: str | None = None
38+
3839
_service: Service | None = None
3940

4041
@property

splunklib/ai/tools.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ def _splunk_home() -> str:
4747
return splunk_home
4848

4949

50-
def locate_tools_path_by_sdk_location(
50+
def locate_app(
5151
splunk_home: str | None = None, sdk_location_path: str = __file__
52-
) -> str:
52+
) -> tuple[str, str]:
5353
"""
5454
This function returns the path to the tools file of the app, assumes that the SDK
5555
is vendored into the app.
@@ -76,7 +76,11 @@ def locate_tools_path_by_sdk_location(
7676
assert parts[0] != "." and parts[1] != ".."
7777

7878
app_id = parts[0]
79-
return os.path.join(splunk_home, "etc", "apps", app_id, "bin", TOOLS_FILENAME)
79+
return (app_id, os.path.join(splunk_home, "etc", "apps", app_id))
80+
81+
82+
def build_local_tools_path(dir: str) -> str:
83+
return os.path.join(dir, "bin", TOOLS_FILENAME)
8084

8185

8286
@dataclass
@@ -90,6 +94,8 @@ class LocalCfg:
9094
class RemoteCfg:
9195
mcp_url: str
9296
token: str
97+
app_id: str
98+
trace_id: str
9399

94100

95101
@asynccontextmanager
@@ -120,6 +126,10 @@ async def _connect_remote_mcp(cfg: RemoteCfg):
120126
async with streamable_http_client(
121127
url=cfg.mcp_url,
122128
http_client=httpx.AsyncClient(
129+
headers={
130+
"x-splunk-trace-id": cfg.trace_id,
131+
"x-splunk-app-id": cfg.app_id,
132+
},
123133
auth=_MCPAuth(f"Bearer {cfg.token}"),
124134
verify=False,
125135
follow_redirects=True,
@@ -174,17 +184,29 @@ def _convert_mcp_tool(
174184
async def call_tool(
175185
**arguments: dict[str, Any],
176186
) -> ToolResult:
177-
# Provide access to the splunk instance in local tools.
178-
# No need to do anything special for remote tools, since
179-
# these tools are already authenticated with the token.
180187
meta: dict[str, Any] | None = None
181-
if isinstance(cfg, LocalCfg):
182-
meta = {
183-
"splunk": {
184-
"management_url": cfg.management_url,
185-
"management_token": cfg.token,
188+
match cfg:
189+
case LocalCfg():
190+
meta = {
191+
"splunk": {
192+
# Provide access to the splunk instance in local tools.
193+
# No need to do anything special for remote tools, since
194+
# these tools are already authenticated with the token.
195+
"management_url": cfg.management_url,
196+
"management_token": cfg.token,
197+
# Currently we don't need to send the trace_id and app_id to local tools, since
198+
# that is only really needed to correlate logs, but for local tools we know
199+
# that logs coming from the local tool registry are already reladed to this
200+
# agent.
201+
}
202+
}
203+
case RemoteCfg():
204+
meta = {
205+
"splunk": {
206+
"trace_id": cfg.trace_id,
207+
"app_id": cfg.app_id,
208+
}
186209
}
187-
}
188210

189211
async with _connect(cfg) as session:
190212
call_tool_result = await session.call_tool(
@@ -291,7 +313,9 @@ async def _load_tools(cfg: LocalCfg | RemoteCfg) -> list[Tool]:
291313

292314
async def load_mcp_tools(
293315
service: Service,
294-
local_tools_path: str | None = None,
316+
local_tools_path: str | None,
317+
app_id: str,
318+
trace_id: str,
295319
) -> list[Tool]:
296320
# TODO: Add tool.name collision between local/remote tools
297321
tools: list[Tool] = []
@@ -304,7 +328,9 @@ async def load_mcp_tools(
304328
client = httpx.AsyncClient(auth=_MCPAuth(f"Bearer {token}"), verify=False)
305329
res = await client.get(mcp_url)
306330
if res.status_code != 404:
307-
remote_tools = await _load_tools(RemoteCfg(mcp_url=mcp_url, token=token))
331+
remote_tools = await _load_tools(
332+
RemoteCfg(mcp_url=mcp_url, token=token, app_id=app_id, trace_id=trace_id)
333+
)
308334
tools.extend(remote_tools)
309335

310336
if local_tools_path is not None:

tests/integration/ai/test_agent_mcp_tools.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,23 @@
55
from unittest.mock import patch
66

77
import pytest
8+
from starlette.middleware import Middleware
89
import uvicorn
9-
from mcp.server.fastmcp import FastMCP
10+
from mcp.server.fastmcp import Context, FastMCP
1011
from pydantic import BaseModel
1112
from starlette.applications import Starlette
1213
from starlette.requests import Request
1314
from starlette.responses import JSONResponse, Response
1415
from starlette.routing import Mount, Route
16+
from starlette.middleware.base import BaseHTTPMiddleware
1517

1618
from splunklib.ai import Agent
1719
from splunklib.ai.messages import HumanMessage, ToolMessage
1820
from splunklib.ai.tool_filtering import ToolFilters
1921
from splunklib.ai.tools import (
2022
_get_splunk_token_for_mcp,
2123
_get_splunk_username,
22-
locate_tools_path_by_sdk_location,
24+
locate_app,
2325
)
2426
from splunklib.client import connect
2527
from tests import testlib
@@ -38,6 +40,7 @@ class TestTools(AITestCase):
3840
"weather.py",
3941
),
4042
)
43+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
4144
async def test_tool_execution_structured_output(self) -> None:
4245
# Skip if the langchain_openai package is not installed
4346
pytest.importorskip("langchain_openai")
@@ -77,6 +80,7 @@ async def test_tool_execution_structured_output(self) -> None:
7780
"tool_context.py",
7881
),
7982
)
83+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
8084
async def test_tool_execution_service_access(self) -> None:
8185
# Skip if the langchain_openai package is not installed
8286
pytest.importorskip("langchain_openai")
@@ -114,6 +118,7 @@ async def test_tool_execution_service_access(self) -> None:
114118
"splunklib.ai.agent._testing_local_tools_path",
115119
os.path.join(os.path.dirname(__file__), "testdata", "tool_filtering.py"),
116120
)
121+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
117122
@pytest.mark.asyncio
118123
async def test_agent_filtering_tools(self) -> None:
119124
pytest.importorskip("langchain_openai")
@@ -151,16 +156,17 @@ def test_get_splunk_username(self) -> None:
151156
self.assertEqual(_get_splunk_username(service), self.service.username)
152157

153158

154-
class TestToolsPathInference:
155-
def test_infer_tools_path(self) -> None:
159+
class TestAppLocate:
160+
def test_locate_app(self) -> None:
156161
path = os.path.join(os.path.dirname(__file__), "testdata", "app-inference")
157-
got = locate_tools_path_by_sdk_location(
162+
app_id, app_dir = locate_app(
158163
splunk_home=path,
159164
sdk_location_path=os.path.join(
160165
path, "etc", "apps", "appname", "bin", "lib", "somefile.py"
161166
),
162167
)
163-
assert got == os.path.join(path, "etc", "apps", "appname", "bin", "tools.py")
168+
assert app_id == "appname"
169+
assert app_dir == os.path.join(path, "etc", "apps", "appname")
164170

165171

166172
AUTH_TOKEN = "foobarbaz"
@@ -197,14 +203,26 @@ class TestRemoteTools(AITestCase):
197203
"non_existent.py",
198204
),
199205
)
206+
@patch("splunklib.ai.agent._testing_app_id", "fancyapp")
200207
@pytest.mark.asyncio
201208
async def test_remote_tools(self):
202209
pytest.importorskip("langchain_openai")
203210

204211
mcp = FastMCP("MCP Server", streamable_http_path="/")
205212

213+
trace_id: str | None = None
214+
app_id: str | None = None
215+
206216
@mcp.tool(description="Returns the current temperature in the city")
207-
def temperature(city: str) -> str:
217+
def temperature(ctx: Context, city: str) -> str:
218+
nonlocal trace_id, app_id
219+
assert trace_id is None and app_id is None
220+
assert ctx.request_context.meta is not None
221+
meta = ctx.request_context.meta.model_dump()
222+
splunk = meta.get("splunk", {})
223+
trace_id = splunk.get("trace_id")
224+
app_id = splunk.get("app_id")
225+
208226
if city == "Krakow":
209227
return "31.5C"
210228
else:
@@ -215,6 +233,29 @@ async def lifespan(app: Starlette):
215233
async with mcp.session_manager.run():
216234
yield
217235

236+
http_trace_id: str | None = None
237+
http_app_id: str | None = None
238+
middleware_called = False
239+
240+
class MCPMiddleware(BaseHTTPMiddleware):
241+
async def dispatch(self, request: Request, call_next):
242+
if request.url.path.startswith("/services/mcp/"):
243+
nonlocal http_trace_id, http_app_id, middleware_called
244+
245+
trace_id = request.headers.get("x-splunk-trace-id")
246+
app_id = request.headers.get("x-splunk-app-id")
247+
248+
# Make sure header values do not change over time.
249+
if middleware_called:
250+
assert http_trace_id == trace_id
251+
assert http_app_id == app_id
252+
253+
middleware_called = True
254+
http_trace_id = trace_id
255+
http_app_id = app_id
256+
257+
return await call_next(request)
258+
218259
async with run_http_server(
219260
Starlette(
220261
routes=[
@@ -226,6 +267,7 @@ async def lifespan(app: Starlette):
226267
),
227268
],
228269
lifespan=lifespan,
270+
middleware=[Middleware(MCPMiddleware)],
229271
)
230272
) as (host, port):
231273
service = await asyncio.to_thread(
@@ -266,6 +308,11 @@ async def lifespan(app: Starlette):
266308
response = result.messages[-1].content
267309
assert "31.5" in response, "Invalid LLM response"
268310

311+
assert trace_id == agent.trace_id
312+
assert app_id == "fancyapp"
313+
assert http_trace_id == agent.trace_id
314+
assert http_app_id == "fancyapp"
315+
269316
@patch(
270317
"splunklib.ai.agent._testing_local_tools_path",
271318
os.path.join(
@@ -274,6 +321,7 @@ async def lifespan(app: Starlette):
274321
"non_existent.py",
275322
),
276323
)
324+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
277325
@pytest.mark.asyncio
278326
async def test_remote_tools_mcp_app_unavail(self):
279327
pytest.importorskip("langchain_openai")
@@ -326,6 +374,7 @@ async def test_remote_tools_mcp_app_unavail(self):
326374
"non_existent.py",
327375
),
328376
)
377+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
329378
@pytest.mark.asyncio
330379
async def test_remote_tools_failure(self):
331380
pytest.importorskip("langchain_openai")

0 commit comments

Comments
 (0)