Skip to content

Commit a9cbdc7

Browse files
authored
Expose Logger in ToolContext (#50)
This change exposes a Logger inside of a ToolContext, which allows developers to instrument their tools with logs. This change only exposes such logging functionality to tools; a follow-up change will collect these logs and send them into the Agent logger.
1 parent 9304de2 commit a9cbdc7

4 files changed

Lines changed: 343 additions & 32 deletions

File tree

splunklib/ai/README.md

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ if __name__ == "__main__":
155155
Unlike regular tool inputs, this parameter is not provided by the LLM. Instead, it is
156156
automatically injected by the runtime for every tool invocation.
157157

158-
`ToolContext` currently provides access to the SDK’s `Service` object, allowing tools to perform
158+
159+
##### Service access
160+
161+
`ToolContext` provides access to the SDK’s `Service` object, allowing tools to perform
159162
authenticated actions against Splunk on behalf of the **user who executed the Agent**.
160163

161164
```py
@@ -177,6 +180,22 @@ def runSplunkQuery(ctx: ToolContext) -> list[str]:
177180
return output
178181
```
179182

183+
##### Logger access
184+
185+
`ToolContext` exposes a `Logger` instance that can be used for logging within your tool implementation.
186+
187+
188+
```py
189+
from splunklib.ai.registry import ToolContext
190+
191+
@registry.tool()
192+
def tool(ctx: ToolContext) -> None:
193+
ctx.logger.info("executing tool")
194+
195+
```
196+
In this example, the `Logger` instance is accessed via `ctx.logger` and used to emit an informational
197+
log message during tool execution.
198+
180199
### Tool filtering
181200

182201
Tools can be filtered, before these are made available to the LLM, via the `tool_filters` parameter.

splunklib/ai/registry.py

Lines changed: 178 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,129 @@
1414
# under the License.
1515
import asyncio
1616
import inspect
17+
import logging
1718
from collections.abc import Sequence
1819
from dataclasses import asdict, dataclass
19-
from typing import Any, Callable, Generic, ParamSpec, TypeVar, get_type_hints
20+
from logging import Logger
21+
from typing import (
22+
Any,
23+
Callable,
24+
Generic,
25+
ParamSpec,
26+
TypeVar,
27+
get_type_hints,
28+
override,
29+
)
2030

2131
import mcp.types as types
32+
from mcp import LoggingLevel, ServerSession
2233
from mcp.server.lowlevel import Server
2334
from pydantic import TypeAdapter
2435

2536
from splunklib.binding import _spliturl
2637
from splunklib.client import Service, connect
2738

2839

40+
def _normalize_logger_level(levelno: int) -> int:
41+
if levelno < logging.INFO:
42+
return logging.DEBUG
43+
elif levelno < logging.WARNING:
44+
return logging.INFO
45+
elif levelno < logging.ERROR:
46+
return logging.WARN
47+
elif levelno < logging.CRITICAL:
48+
return logging.ERROR
49+
else:
50+
return logging.CRITICAL
51+
52+
53+
def _map_logger_to_mcp_logging_level(levelno: int) -> types.LoggingLevel:
54+
match _normalize_logger_level(levelno):
55+
case logging.FATAL:
56+
return "critical"
57+
case logging.ERROR:
58+
return "error"
59+
case logging.WARN:
60+
return "warning"
61+
case logging.INFO:
62+
return "info"
63+
case logging.DEBUG:
64+
return "debug"
65+
case _:
66+
raise AssertionError("invalid logging level")
67+
68+
69+
def _min_logging_level(level: types.LoggingLevel) -> int:
70+
match level:
71+
case "debug":
72+
return logging.NOTSET
73+
case "info":
74+
return logging.INFO
75+
case "notice":
76+
return logging.INFO
77+
case "warning":
78+
return logging.WARN
79+
case "error":
80+
return logging.ERROR
81+
case "critical":
82+
return logging.CRITICAL
83+
case "alert":
84+
return logging.CRITICAL
85+
case "emergency":
86+
return logging.CRITICAL
87+
88+
89+
class _MCPLoggingHandler(logging.Handler):
90+
_group: asyncio.TaskGroup
91+
_session: ServerSession
92+
_request_id: types.RequestId
93+
94+
def __init__(
95+
self,
96+
group: asyncio.TaskGroup,
97+
session: ServerSession,
98+
request_id: types.RequestId,
99+
) -> None:
100+
self._group = group
101+
self._session = session
102+
self._request_id = request_id
103+
super().__init__()
104+
105+
@override
106+
def emit(self, record: logging.LogRecord) -> None:
107+
mcp_level = _map_logger_to_mcp_logging_level(record.levelno)
108+
109+
async def send_log() -> None:
110+
await self._session.send_log_message(
111+
level=mcp_level,
112+
data=record.msg,
113+
logger="",
114+
related_request_id=self._request_id,
115+
)
116+
117+
# We can't await send_log() here, so we create a task that will
118+
# send the logs concurrently.
119+
#
120+
# Note: Since these log tasks are executed concurrently, the logs might not be sent
121+
# in the same order in which they were created.
122+
# The root cause of this is that log Handlers cannot be async.
123+
#
124+
# We could fix this with the use of asyncio.Queue().put_nowait, but that
125+
# has the problem that it might raise a QueueFull exception if there
126+
# are a bunch of logs created. We would have to handle that exception with
127+
# a create_task(send_log()), which would still cause such unordered execution.
128+
#
129+
# Alternatively, we could maintain a set of all tasks that are not yet completed
130+
# and await them in send_log, before calling the send_log_message, but note
131+
# that this would require a clone of that set here, before creating the task
132+
# (and also remove each completed task from that set via task.add_done_callback()).
133+
#
134+
# I also wonder whether task.add_done_callback() could be leveraged to order these tasks
135+
# i.e. by storing the previous task (self._task) and setting self._task.add_done_callback()
136+
# to execute send_log() when self._task.done == False.
137+
_ = self._group.create_task(send_log())
138+
139+
29140
class ToolContext:
30141
"""
31142
ToolContext provides a way to interact with the tool execution context.
@@ -35,6 +146,7 @@ class ToolContext:
35146

36147
_management_url: str | None = None
37148
_management_token: str | None = None
149+
_logger: Logger | None = None
38150

39151
_service: Service | None = None
40152

@@ -63,6 +175,14 @@ def service(self) -> Service:
63175
self._service = s
64176
return s
65177

178+
@property
179+
def logger(self) -> Logger:
180+
"""
181+
This logger can be used by tools to emit logs during execution of a tool.
182+
"""
183+
assert self._logger is not None
184+
return self._logger
185+
66186

67187
_T = TypeVar("_T", default=Any)
68188

@@ -89,6 +209,8 @@ class ToolRegistry:
89209
_tools_wrapped_result: dict[str, bool]
90210
_executing: bool = False
91211

212+
_logging_level: LoggingLevel = "warning"
213+
92214
def __init__(self) -> None:
93215
self._server = Server("Tool Registry")
94216
self._tools = []
@@ -105,6 +227,14 @@ async def _() -> list[types.Tool]:
105227
async def _(name: str, arguments: dict[str, Any]) -> types.CallToolResult:
106228
return await self._call_tool(name, arguments)
107229

230+
@self._server.set_logging_level()
231+
async def _(level: LoggingLevel) -> None:
232+
# Note: We do not update the logging level of already created loggers, see `self._call_tool`,
233+
# but that is fine for our use case, since we only call the set_logging_level once, before
234+
# tool calls.
235+
self._logging_level = level
236+
return None
237+
108238
def _list_tools(self) -> list[types.Tool]:
109239
return self._tools
110240

@@ -115,35 +245,56 @@ async def _call_tool(
115245
if func is None:
116246
raise ValueError(f"Tool {name} does not exist")
117247

118-
ctx = ToolContext()
119-
meta = self._server.request_context.meta
120-
if meta is not None:
121-
splunk_meta = meta.model_dump().get("splunk")
122-
if splunk_meta is not None:
123-
ctx._management_url = splunk_meta.get("management_url")
124-
ctx._management_token = splunk_meta.get("management_token")
248+
req_ctx = self._server.request_context
125249

126-
for k in func.__annotations__:
127-
if func.__annotations__[k] == ToolContext:
128-
assert arguments.get(k) is None, (
129-
"Improper input schema was generated or schema verification is malfunctioning"
250+
try:
251+
# Use a TaskGroup such that all logs are sent before finishing the tool execution
252+
# and all errors propagated (if any).
253+
async with asyncio.TaskGroup() as task_group:
254+
handler = _MCPLoggingHandler(
255+
task_group,
256+
req_ctx.session,
257+
req_ctx.request_id,
130258
)
131-
arguments[k] = ctx
132-
133-
res = func(**arguments)
134259

135-
# In case func was an async function, await the returned coroutine.
136-
# If not then we already have the result.
137-
if inspect.isawaitable(res):
138-
res = await res
139-
140-
if self._tools_wrapped_result.get(name):
141-
res = _WrappedResult(res)
142-
143-
return types.CallToolResult(
144-
structuredContent=asdict(res),
145-
content=[],
146-
)
260+
# Create a logger that forwards all logs to the client over MCP.
261+
logger = logging.Logger(name="MCP Logger")
262+
logger.setLevel(_min_logging_level(self._logging_level))
263+
logger.addHandler(handler)
264+
265+
ctx = ToolContext()
266+
ctx._logger = logger
267+
meta = req_ctx.meta
268+
if meta is not None:
269+
splunk_meta = meta.model_dump().get("splunk")
270+
if splunk_meta is not None:
271+
ctx._management_url = splunk_meta.get("management_url")
272+
ctx._management_token = splunk_meta.get("management_token")
273+
274+
for k in func.__annotations__:
275+
if func.__annotations__[k] == ToolContext:
276+
assert arguments.get(k) is None, (
277+
"Improper input schema was generated or schema verification is malfunctioning"
278+
)
279+
arguments[k] = ctx
280+
281+
res = func(**arguments)
282+
283+
# In case func was an async function, await the returned coroutine.
284+
# If not then we already have the result.
285+
if inspect.isawaitable(res):
286+
res = await res
287+
288+
if self._tools_wrapped_result.get(name):
289+
res = _WrappedResult(res)
290+
291+
return types.CallToolResult(
292+
structuredContent=asdict(res),
293+
content=[],
294+
)
295+
except BaseExceptionGroup as e:
296+
# Re-raise the first exception.
297+
raise e.exceptions[0]
147298

148299
def _input_schema(self, func: Callable[_P, _R]) -> dict[str, Any]:
149300
"""

0 commit comments

Comments
 (0)