1414# under the License.
1515import asyncio
1616import inspect
17+ import logging
1718from collections .abc import Sequence
1819from dataclasses import asdict , dataclass
19- from typing import Any , Callable , Generic , ParamSpec , TypeVar , get_type_hints
20+ from logging import Logger
21+ from typing import (
22+ Any ,
23+ Callable ,
24+ Generic ,
25+ ParamSpec ,
26+ TypeVar ,
27+ get_type_hints ,
28+ override ,
29+ )
2030
2131import mcp .types as types
32+ from mcp import LoggingLevel , ServerSession
2233from mcp .server .lowlevel import Server
2334from pydantic import TypeAdapter
2435
2536from splunklib .binding import _spliturl
2637from splunklib .client import Service , connect
2738
2839
40+ def _normalize_logger_level (levelno : int ) -> int :
41+ if levelno < logging .INFO :
42+ return logging .DEBUG
43+ elif levelno < logging .WARNING :
44+ return logging .INFO
45+ elif levelno < logging .ERROR :
46+ return logging .WARN
47+ elif levelno < logging .CRITICAL :
48+ return logging .ERROR
49+ else :
50+ return logging .CRITICAL
51+
52+
53+ def _map_logger_to_mcp_logging_level (levelno : int ) -> types .LoggingLevel :
54+ match _normalize_logger_level (levelno ):
55+ case logging .FATAL :
56+ return "critical"
57+ case logging .ERROR :
58+ return "error"
59+ case logging .WARN :
60+ return "warning"
61+ case logging .INFO :
62+ return "info"
63+ case logging .DEBUG :
64+ return "debug"
65+ case _:
66+ raise AssertionError ("invalid logging level" )
67+
68+
69+ def _min_logging_level (level : types .LoggingLevel ) -> int :
70+ match level :
71+ case "debug" :
72+ return logging .NOTSET
73+ case "info" :
74+ return logging .INFO
75+ case "notice" :
76+ return logging .INFO
77+ case "warning" :
78+ return logging .WARN
79+ case "error" :
80+ return logging .ERROR
81+ case "critical" :
82+ return logging .CRITICAL
83+ case "alert" :
84+ return logging .CRITICAL
85+ case "emergency" :
86+ return logging .CRITICAL
87+
88+
89+ class _MCPLoggingHandler (logging .Handler ):
90+ _group : asyncio .TaskGroup
91+ _session : ServerSession
92+ _request_id : types .RequestId
93+
94+ def __init__ (
95+ self ,
96+ group : asyncio .TaskGroup ,
97+ session : ServerSession ,
98+ request_id : types .RequestId ,
99+ ) -> None :
100+ self ._group = group
101+ self ._session = session
102+ self ._request_id = request_id
103+ super ().__init__ ()
104+
105+ @override
106+ def emit (self , record : logging .LogRecord ) -> None :
107+ mcp_level = _map_logger_to_mcp_logging_level (record .levelno )
108+
109+ async def send_log () -> None :
110+ await self ._session .send_log_message (
111+ level = mcp_level ,
112+ data = record .msg ,
113+ logger = "" ,
114+ related_request_id = self ._request_id ,
115+ )
116+
117+ # We can't await send_log() here, so we create a task, that will
118+ # send the logs concurrently.
119+ #
120+ # Note: These logs, since are executed concurrently might not be sent
121+ # in the same order, in which were created.
122+ # The root cause of this is that log Handlers cannot be async.
123+ #
124+ # We could fix this with the use of a asyncio.Queue().put_nowait, but that
125+ # has a problem, that it might raise an QueueFull exception, if there
126+ # are bunch of logs created. We would have to handle that exception with
127+ # a create_task(send_log()), which would still cause such unordered execution.
128+ #
129+ # Alternatively, we could maintain a set of all tasks that are not yet completed
130+ # and await them in send_log, before calling the send_log_message, but note
131+ # that this would require a clone of that set here, before creating the task
132+ # (also a removal of a task from that set (task.add_done_callback())
133+ #
134+ # I also wonder whether task.add_done_callback() could be leveraged to order these tasks
135+ # i.e. by storing the previous task (self._task) and setting self._task.add_done_callback()
136+ # to execute send_log() when self._task.done == False.
137+ _ = self ._group .create_task (send_log ())
138+
139+
29140class ToolContext :
30141 """
31142 ToolContext provides a way to interact with the tool execution context.
@@ -35,6 +146,7 @@ class ToolContext:
35146
36147 _management_url : str | None = None
37148 _management_token : str | None = None
149+ _logger : Logger | None = None
38150
39151 _service : Service | None = None
40152
@@ -63,6 +175,14 @@ def service(self) -> Service:
63175 self ._service = s
64176 return s
65177
178+ @property
179+ def logger (self ) -> Logger :
180+ """
181+ This logger can be used by tools to emit logs during execution of a tool.
182+ """
183+ assert self ._logger is not None
184+ return self ._logger
185+
66186
67187_T = TypeVar ("_T" , default = Any )
68188
@@ -89,6 +209,8 @@ class ToolRegistry:
89209 _tools_wrapped_result : dict [str , bool ]
90210 _executing : bool = False
91211
212+ _logging_level : LoggingLevel = "warning"
213+
92214 def __init__ (self ) -> None :
93215 self ._server = Server ("Tool Registry" )
94216 self ._tools = []
@@ -105,6 +227,14 @@ async def _() -> list[types.Tool]:
105227 async def _ (name : str , arguments : dict [str , Any ]) -> types .CallToolResult :
106228 return await self ._call_tool (name , arguments )
107229
230+ @self ._server .set_logging_level ()
231+ async def _ (level : LoggingLevel ) -> None :
232+ # Note: We do not update the logging level of already created loggers, see `self._call_tool`,
233+ # but that is fine for our use case, since we only call the set_logging_level once, before
234+ # tool calls.
235+ self ._logging_level = level
236+ return None
237+
108238 def _list_tools (self ) -> list [types .Tool ]:
109239 return self ._tools
110240
@@ -115,35 +245,56 @@ async def _call_tool(
115245 if func is None :
116246 raise ValueError (f"Tool { name } does not exist" )
117247
118- ctx = ToolContext ()
119- meta = self ._server .request_context .meta
120- if meta is not None :
121- splunk_meta = meta .model_dump ().get ("splunk" )
122- if splunk_meta is not None :
123- ctx ._management_url = splunk_meta .get ("management_url" )
124- ctx ._management_token = splunk_meta .get ("management_token" )
248+ req_ctx = self ._server .request_context
125249
126- for k in func .__annotations__ :
127- if func .__annotations__ [k ] == ToolContext :
128- assert arguments .get (k ) is None , (
129- "Improper input schema was generated or schema verification is malfunctioning"
250+ try :
251+ # Use a TaskGroup such that all logs are send before finishing the tool execution
252+ # and all errors propagated (if any).
253+ async with asyncio .TaskGroup () as task_group :
254+ handler = _MCPLoggingHandler (
255+ task_group ,
256+ req_ctx .session ,
257+ req_ctx .request_id ,
130258 )
131- arguments [k ] = ctx
132-
133- res = func (** arguments )
134259
135- # In case func was an async function, await the returned coroutine.
136- # If not then we already have the result.
137- if inspect .isawaitable (res ):
138- res = await res
139-
140- if self ._tools_wrapped_result .get (name ):
141- res = _WrappedResult (res )
142-
143- return types .CallToolResult (
144- structuredContent = asdict (res ),
145- content = [],
146- )
260+ # Create a logger that forwards all logs to the client over MCP.
261+ logger = logging .Logger (name = "MCP Logger" )
262+ logger .setLevel (_min_logging_level (self ._logging_level ))
263+ logger .addHandler (handler )
264+
265+ ctx = ToolContext ()
266+ ctx ._logger = logger
267+ meta = req_ctx .meta
268+ if meta is not None :
269+ splunk_meta = meta .model_dump ().get ("splunk" )
270+ if splunk_meta is not None :
271+ ctx ._management_url = splunk_meta .get ("management_url" )
272+ ctx ._management_token = splunk_meta .get ("management_token" )
273+
274+ for k in func .__annotations__ :
275+ if func .__annotations__ [k ] == ToolContext :
276+ assert arguments .get (k ) is None , (
277+ "Improper input schema was generated or schema verification is malfunctioning"
278+ )
279+ arguments [k ] = ctx
280+
281+ res = func (** arguments )
282+
283+ # In case func was an async function, await the returned coroutine.
284+ # If not then we already have the result.
285+ if inspect .isawaitable (res ):
286+ res = await res
287+
288+ if self ._tools_wrapped_result .get (name ):
289+ res = _WrappedResult (res )
290+
291+ return types .CallToolResult (
292+ structuredContent = asdict (res ),
293+ content = [],
294+ )
295+ except BaseExceptionGroup as e :
296+ # Re-raise the first exception.
297+ raise e .exceptions [0 ]
147298
148299 def _input_schema (self , func : Callable [_P , _R ]) -> dict [str , Any ]:
149300 """
0 commit comments