55from unittest .mock import patch
66
77import pytest
8+ from starlette .middleware import Middleware
89import uvicorn
9- from mcp .server .fastmcp import FastMCP
10+ from mcp .server .fastmcp import Context , FastMCP
1011from pydantic import BaseModel
1112from starlette .applications import Starlette
1213from starlette .requests import Request
1314from starlette .responses import JSONResponse , Response
1415from starlette .routing import Mount , Route
16+ from starlette .middleware .base import BaseHTTPMiddleware
1517
1618from splunklib .ai import Agent
1719from splunklib .ai .messages import HumanMessage , ToolMessage
1820from splunklib .ai .tool_filtering import ToolFilters
1921from splunklib .ai .tools import (
2022 _get_splunk_token_for_mcp ,
2123 _get_splunk_username ,
22- locate_tools_path_by_sdk_location ,
24+ locate_app ,
2325)
2426from splunklib .client import connect
2527from tests import testlib
@@ -38,6 +40,7 @@ class TestTools(AITestCase):
3840 "weather.py" ,
3941 ),
4042 )
43+ @patch ("splunklib.ai.agent._testing_app_id" , "app_id" )
4144 async def test_tool_execution_structured_output (self ) -> None :
4245 # Skip if the langchain_openai package is not installed
4346 pytest .importorskip ("langchain_openai" )
@@ -77,6 +80,7 @@ async def test_tool_execution_structured_output(self) -> None:
7780 "tool_context.py" ,
7881 ),
7982 )
83+ @patch ("splunklib.ai.agent._testing_app_id" , "app_id" )
8084 async def test_tool_execution_service_access (self ) -> None :
8185 # Skip if the langchain_openai package is not installed
8286 pytest .importorskip ("langchain_openai" )
@@ -114,6 +118,7 @@ async def test_tool_execution_service_access(self) -> None:
114118 "splunklib.ai.agent._testing_local_tools_path" ,
115119 os .path .join (os .path .dirname (__file__ ), "testdata" , "tool_filtering.py" ),
116120 )
121+ @patch ("splunklib.ai.agent._testing_app_id" , "app_id" )
117122 @pytest .mark .asyncio
118123 async def test_agent_filtering_tools (self ) -> None :
119124 pytest .importorskip ("langchain_openai" )
@@ -151,16 +156,17 @@ def test_get_splunk_username(self) -> None:
151156 self .assertEqual (_get_splunk_username (service ), self .service .username )
152157
153158
154- class TestToolsPathInference :
155- def test_infer_tools_path (self ) -> None :
159+ class TestAppLocate :
160+ def test_locate_app (self ) -> None :
156161 path = os .path .join (os .path .dirname (__file__ ), "testdata" , "app-inference" )
157- got = locate_tools_path_by_sdk_location (
162+ app_id , app_dir = locate_app (
158163 splunk_home = path ,
159164 sdk_location_path = os .path .join (
160165 path , "etc" , "apps" , "appname" , "bin" , "lib" , "somefile.py"
161166 ),
162167 )
163- assert got == os .path .join (path , "etc" , "apps" , "appname" , "bin" , "tools.py" )
168+ assert app_id == "appname"
169+ assert app_dir == os .path .join (path , "etc" , "apps" , "appname" )
164170
165171
166172AUTH_TOKEN = "foobarbaz"
@@ -197,14 +203,26 @@ class TestRemoteTools(AITestCase):
197203 "non_existent.py" ,
198204 ),
199205 )
206+ @patch ("splunklib.ai.agent._testing_app_id" , "fancyapp" )
200207 @pytest .mark .asyncio
201208 async def test_remote_tools (self ):
202209 pytest .importorskip ("langchain_openai" )
203210
204211 mcp = FastMCP ("MCP Server" , streamable_http_path = "/" )
205212
213+ trace_id : str | None = None
214+ app_id : str | None = None
215+
206216 @mcp .tool (description = "Returns the current temperature in the city" )
207- def temperature (city : str ) -> str :
217+ def temperature (ctx : Context , city : str ) -> str :
218+ nonlocal trace_id , app_id
219+ assert trace_id is None and app_id is None
220+ assert ctx .request_context .meta is not None
221+ meta = ctx .request_context .meta .model_dump ()
222+ splunk = meta .get ("splunk" , {})
223+ trace_id = splunk .get ("trace_id" )
224+ app_id = splunk .get ("app_id" )
225+
208226 if city == "Krakow" :
209227 return "31.5C"
210228 else :
@@ -215,6 +233,29 @@ async def lifespan(app: Starlette):
215233 async with mcp .session_manager .run ():
216234 yield
217235
236+ http_trace_id : str | None = None
237+ http_app_id : str | None = None
238+ middleware_called = False
239+
240+ class MCPMiddleware (BaseHTTPMiddleware ):
241+ async def dispatch (self , request : Request , call_next ):
242+ if request .url .path .startswith ("/services/mcp/" ):
243+ nonlocal http_trace_id , http_app_id , middleware_called
244+
245+ trace_id = request .headers .get ("x-splunk-trace-id" )
246+ app_id = request .headers .get ("x-splunk-app-id" )
247+
248+ # Make sure header values do not change over time.
249+ if middleware_called :
250+ assert http_trace_id == trace_id
251+ assert http_app_id == app_id
252+
253+ middleware_called = True
254+ http_trace_id = trace_id
255+ http_app_id = app_id
256+
257+ return await call_next (request )
258+
218259 async with run_http_server (
219260 Starlette (
220261 routes = [
@@ -226,6 +267,7 @@ async def lifespan(app: Starlette):
226267 ),
227268 ],
228269 lifespan = lifespan ,
270+ middleware = [Middleware (MCPMiddleware )],
229271 )
230272 ) as (host , port ):
231273 service = await asyncio .to_thread (
@@ -266,6 +308,11 @@ async def lifespan(app: Starlette):
266308 response = result .messages [- 1 ].content
267309 assert "31.5" in response , "Invalid LLM response"
268310
311+ assert trace_id == agent .trace_id
312+ assert app_id == "fancyapp"
313+ assert http_trace_id == agent .trace_id
314+ assert http_app_id == "fancyapp"
315+
269316 @patch (
270317 "splunklib.ai.agent._testing_local_tools_path" ,
271318 os .path .join (
@@ -274,6 +321,7 @@ async def lifespan(app: Starlette):
274321 "non_existent.py" ,
275322 ),
276323 )
324+ @patch ("splunklib.ai.agent._testing_app_id" , "app_id" )
277325 @pytest .mark .asyncio
278326 async def test_remote_tools_mcp_app_unavail (self ):
279327 pytest .importorskip ("langchain_openai" )
@@ -326,6 +374,7 @@ async def test_remote_tools_mcp_app_unavail(self):
326374 "non_existent.py" ,
327375 ),
328376 )
377+ @patch ("splunklib.ai.agent._testing_app_id" , "app_id" )
329378 @pytest .mark .asyncio
330379 async def test_remote_tools_failure (self ):
331380 pytest .importorskip ("langchain_openai" )
0 commit comments