Skip to content

Commit 7e1ec18

Browse files
authored
Don't use tokens to create a Service in local tools (#64)
Instead serialize all of the auth fields and propagate them to the local tools. I also noted during testing that on Cloud, that token auth needs to be enabled explicitly and without doing so the authorization/tokens fails. With this approach we use identical credentials as the Service, thus eliminate this problem.
1 parent bbdedea commit 7e1ec18

6 files changed

Lines changed: 193 additions & 140 deletions

File tree

splunklib/ai/registry.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
from mcp.server.lowlevel import Server
3333
from pydantic import TypeAdapter
3434

35-
from splunklib.binding import _spliturl
36-
from splunklib.client import Service, connect
35+
from splunklib.ai.serialized_service import SerializedService
36+
from splunklib.client import Service
3737

3838

3939
def _normalize_logger_level(levelno: int) -> int:
@@ -147,8 +147,7 @@ class _ToolContextParams:
147147
in this class name i.e. internal class).
148148
"""
149149

150-
management_url: str | None
151-
management_token: str | None
150+
service: SerializedService | None
152151
logger: Logger
153152

154153

@@ -176,21 +175,13 @@ def service(self) -> Service:
176175
if self._service is not None:
177176
return self._service
178177

179-
assert all((self._params.management_url, self._params.management_token)), (
180-
"Invalid tool invocation, missing management_url and/or management_token"
178+
assert self._params.service is not None, (
179+
"Invalid tool invocation, missing serialized service details"
181180
)
182181

183-
scheme, host, port, path = _spliturl(self._params.management_url)
184-
s = connect(
185-
scheme=scheme,
186-
host=host,
187-
port=port,
188-
path=path,
189-
token=self._params.management_token,
190-
autologin=True,
191-
)
192-
self._service = s
193-
return s
182+
# TODO: Shouldn't this function be async and this use asyncio.to_thread()?
183+
self._service = self._params.service.connect()
184+
return self._service
194185

195186
@property
196187
def logger(self) -> Logger:
@@ -281,20 +272,19 @@ async def _call_tool(
281272
logger.setLevel(_min_logging_level(self._logging_level))
282273
logger.addHandler(handler)
283274

284-
management_url: str | None = None
285-
management_token: str | None = None
275+
service: SerializedService | None = None
286276

287277
meta = req_ctx.meta
288278
if meta is not None:
289279
splunk_meta = meta.model_dump().get("splunk")
290280
if splunk_meta is not None:
291-
management_url = splunk_meta.get("management_url")
292-
management_token = splunk_meta.get("management_token")
281+
service = SerializedService.model_validate(
282+
splunk_meta.get("service")
283+
)
293284

294285
ctx = ToolContext(
295286
params=_ToolContextParams(
296-
management_url=management_url,
297-
management_token=management_token,
287+
service=service,
298288
logger=logger,
299289
)
300290
)

splunklib/ai/serialized_service.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright © 2011-2026 Splunk, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"): you may
4+
# not use this file except in compliance with the License. You may obtain
5+
# a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
15+
16+
from typing import Self
17+
18+
from pydantic import BaseModel
19+
20+
from splunklib.binding import _spliturl
21+
from splunklib.client import Service, connect
22+
23+
24+
class SerializedService(BaseModel):
25+
management_url: str = ""
26+
username: str | None = None
27+
password: str | None = None
28+
token: str | None = None
29+
bearer_token: str | None = None
30+
auth_cookies: dict[str, str] | None = None
31+
32+
@classmethod
33+
def from_service(cls, service: Service) -> Self:
34+
return cls(
35+
management_url=f"{service.scheme}://{service.host}:{service.port}", # pyright: ignore[reportUnknownMemberType]
36+
username=service.username if service.username else None, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
37+
password=service.password if service.password else None, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
38+
token=service.token if isinstance(service.token, str) else None, # pyright: ignore[reportUnknownMemberType, reportArgumentType]
39+
bearer_token=service.bearerToken if service.bearerToken else None, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
40+
auth_cookies=(
41+
service.get_cookies() if len(service.get_cookies()) != 0 else None # pyright: ignore[reportUnknownArgumentType]
42+
),
43+
)
44+
45+
def connect(self) -> Service:
46+
scheme, host, port, path = _spliturl(self.management_url) # pyright: ignore[reportUnknownVariableType]
47+
return connect(
48+
scheme=scheme, # pyright: ignore[reportUnknownArgumentType]
49+
host=host, # pyright: ignore[reportUnknownArgumentType]
50+
port=port,
51+
path=path,
52+
username=self.username if self.username else None,
53+
password=self.password if self.password else None,
54+
token=self.token if self.token else None,
55+
splunkToken=self.bearer_token if self.bearer_token else None,
56+
cookie="; ".join(
57+
f"{key}={self.auth_cookies[key]}" for key in self.auth_cookies
58+
)
59+
if self.auth_cookies
60+
else None,
61+
autologin=True,
62+
)

splunklib/ai/tools.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from mcp.types import Tool as MCPTool
2424
from pydantic import BaseModel
2525

26+
from splunklib.ai.serialized_service import SerializedService
2627
from splunklib.binding import HTTPError
2728
from splunklib.client import Service
2829
from splunklib.ai.registry import _map_logger_to_mcp_logging_level
@@ -137,8 +138,7 @@ async def __call__(
137138
@dataclass
138139
class LocalCfg:
139140
tools_path: str
140-
management_url: str
141-
token: str
141+
service: SerializedService
142142

143143

144144
@dataclass
@@ -256,8 +256,7 @@ async def call_tool(
256256
# Provide access to the splunk instance in local tools.
257257
# No need to do anything special for remote tools, since
258258
# these tools are already authenticated with the token.
259-
"management_url": cfg.management_url,
260-
"management_token": cfg.token,
259+
"service": cfg.service.model_dump(),
261260
# Currently we don't need to send the trace_id and app_id to local tools, since
262261
# that is only really needed to correlate logs, but for local tools we know
263262
# that logs coming from the local tool registry are already reladed to this
@@ -342,30 +341,6 @@ class ResponseBody(BaseModel):
342341
return body.entry[0].content.username
343342

344343

345-
def _get_splunk_token_for_mcp(service: Service) -> str:
346-
res = service.post(
347-
path_segment="authorization/tokens",
348-
name=_get_splunk_username(service),
349-
audience="mcp",
350-
type="ephemeral",
351-
output_mode="json",
352-
)
353-
354-
class Content(BaseModel):
355-
token: str
356-
357-
class Entry(BaseModel):
358-
content: Content
359-
360-
class ResponseBody(BaseModel):
361-
entry: list[Entry]
362-
363-
body = ResponseBody.model_validate_json(str(res.body))
364-
if len(body.entry) == 0:
365-
return ""
366-
return body.entry[0].content.token
367-
368-
369344
def _get_mcp_token(service: Service) -> str | None:
370345
try:
371346
res = service.get(
@@ -401,7 +376,6 @@ async def load_mcp_tools(
401376

402377
management_url = f"{service.scheme}://{service.host}:{service.port}"
403378
mcp_url = f"{management_url}/services/mcp"
404-
token = await asyncio.to_thread(lambda: _get_splunk_token_for_mcp(service))
405379

406380
mcp_token = await asyncio.to_thread(lambda: _get_mcp_token(service))
407381
if mcp_token is not None:
@@ -425,11 +399,7 @@ async def load_mcp_tools(
425399
local_tools = await _load_tools(
426400
LocalCfg(
427401
tools_path=local_tools_path,
428-
management_url=management_url,
429-
# TODO: Is this right? I think we should do this differentlly and either serialize
430-
# the Service auth fields and send them or generate a separate token, that does not have
431-
# the "mcp" audience set.
432-
token=token,
402+
service=SerializedService.from_service(service),
433403
),
434404
logger,
435405
)

tests/integration/ai/test_agent_mcp_tools.py

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextlib
33
from dataclasses import asdict, dataclass
44
import logging
5+
import json
56
import os
67
import socket
78
from typing import Annotated
@@ -23,7 +24,6 @@
2324
from splunklib.ai.messages import HumanMessage, ToolMessage
2425
from splunklib.ai.tool_filtering import ToolFilters
2526
from splunklib.ai.tools import (
26-
_get_splunk_token_for_mcp,
2727
_get_splunk_username,
2828
locate_app,
2929
)
@@ -140,21 +140,30 @@ async def test_agent_filtering_tools(self) -> None:
140140
assert tool_names == ["test_tool_1", "test_tool_2", "test_tool_4"]
141141

142142

143-
class TestSplunkToken(testlib.SDKTestCase):
143+
class TestSplunkGetUsername(testlib.SDKTestCase):
144+
def get_splunk_bearer_token(self) -> str:
145+
res = self.service.post(
146+
path_segment="authorization/tokens",
147+
name=self.service.username,
148+
audience="test",
149+
type="ephemeral",
150+
output_mode="json",
151+
)
152+
token = json.loads(str(res.body))["entry"][0]["content"]["token"]
153+
return token
154+
144155
def test_get_splunk_username(self) -> None:
145156
self.assertTrue(
146-
self.service.username is not None and self.service.username != ""
157+
self.service.username and self.service.password
147158
) # our CI logs-in with username and password.
148159

149160
self.assertEqual(_get_splunk_username(self.service), self.service.username)
150161

151-
token = _get_splunk_token_for_mcp(self.service)
152-
153162
service = connect(
154163
scheme=self.service.scheme,
155164
host=self.service.host,
156165
port=self.service.port,
157-
token=token,
166+
token=self.get_splunk_bearer_token(),
158167
)
159168

160169
self.assertEqual(_get_splunk_username(service), self.service.username)
@@ -176,28 +185,6 @@ def test_locate_app(self) -> None:
176185
AUTH_TOKEN = "foobarbaz"
177186

178187

179-
async def tokens_handler(request: Request) -> Response:
180-
class Content(BaseModel):
181-
token: str
182-
183-
class Entry(BaseModel):
184-
content: Content
185-
186-
class ResponseBody(BaseModel):
187-
entry: list[Entry]
188-
189-
body = ResponseBody(
190-
entry=[
191-
Entry(content=Content(token=AUTH_TOKEN)),
192-
]
193-
)
194-
195-
return JSONResponse(
196-
content=body.model_dump(),
197-
status_code=200,
198-
)
199-
200-
201188
async def mcp_token_handler(_: Request) -> Response:
202189
return JSONResponse(
203190
content={"token": AUTH_TOKEN},
@@ -276,11 +263,6 @@ async def dispatch(self, request: Request, call_next):
276263
mcp_token_handler,
277264
methods=["GET"],
278265
),
279-
Route(
280-
"/services/authorization/tokens",
281-
tokens_handler,
282-
methods=["POST"],
283-
),
284266
],
285267
lifespan=lifespan,
286268
middleware=[Middleware(MCPMiddleware)],
@@ -344,13 +326,7 @@ async def test_remote_tools_mcp_app_unavail(self):
344326

345327
async with run_http_server(
346328
Starlette(
347-
routes=[
348-
Route(
349-
"/services/authorization/tokens",
350-
tokens_handler,
351-
methods=["POST"],
352-
),
353-
],
329+
routes=[],
354330
)
355331
) as (host, port):
356332
service = await asyncio.to_thread(
@@ -420,11 +396,6 @@ async def lifespan(app: Starlette):
420396
mcp_token_handler,
421397
methods=["GET"],
422398
),
423-
Route(
424-
"/services/authorization/tokens",
425-
tokens_handler,
426-
methods=["POST"],
427-
),
428399
],
429400
lifespan=lifespan,
430401
)
@@ -521,11 +492,6 @@ async def lifespan(app: Starlette):
521492
mcp_token_handler,
522493
methods=["GET"],
523494
),
524-
Route(
525-
"/services/authorization/tokens",
526-
tokens_handler,
527-
methods=["POST"],
528-
),
529495
],
530496
lifespan=lifespan,
531497
)

0 commit comments

Comments
 (0)