1212# License for the specific language governing permissions and limitations
1313# under the License.
1414
15+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
1516from logging import Logger
1617import os
17- from collections .abc import Sequence
18+ from collections .abc import AsyncGenerator , Sequence
1819from typing import Self , final , override
1920
2021from pydantic import BaseModel
2627from splunklib .ai .messages import AgentResponse , BaseMessage , OutputT
2728from splunklib .ai .model import PredefinedModel
2829from splunklib .ai .tool_filtering import ToolFilters , filter_tools
29- from splunklib .ai .tools import Tool , build_local_tools_path , load_mcp_tools , locate_app
30+ from splunklib .ai .tools import (
31+ Tool ,
32+ build_local_tools_path ,
33+ connect_local_mcp ,
34+ connect_remote_mcp ,
35+ load_mcp_tools ,
36+ locate_app ,
37+ )
3038from splunklib .client import Service
3139
3240# For testing purposes, overrides the automatically inferred tools.py path.
@@ -101,14 +109,15 @@ class Agent(BaseAgent[OutputT]):
101109 is appropriate for a given task. Ignored for top-level agents.
102110
103111 logger:
104- Optional logger instance used for tracing and debugging the agent’ s execution.
112+ Optional logger instance used for tracing and debugging the agent' s execution.
105113 Additionally logs from the local tools are forwarded to this logger.
106114 """
107115
108116 _impl : AgentImpl [OutputT ] | None
109117 _use_mcp_tools : bool
110118 _service : Service
111119 _tool_filters : ToolFilters | None
120+ _agent_context_manager : AbstractAsyncContextManager [Self ] | None = None
112121
113122 def __init__ (
114123 self ,
@@ -119,9 +128,9 @@ def __init__(
119128 tool_filters : ToolFilters | None = None ,
120129 agents : Sequence [BaseAgent [BaseModel | None ]] | None = None ,
121130 output_schema : type [OutputT ] | None = None ,
122- input_schema : type [BaseModel ] | None = None , # Only used by Subgents
131+ input_schema : type [BaseModel ] | None = None , # Only used by Subagents
123132 hooks : Sequence [AgentHook ] | None = None ,
124- name : str = "" , # Only used by Subgents
133+ name : str = "" , # Only used by Subagents
125134 description : str = "" , # Only used by Subagents
126135 logger : Logger | None = None ,
127136 ) -> None :
@@ -142,36 +151,94 @@ def __init__(
142151 self ._service = service
143152 self ._impl = None
144153
145- async def __aenter__ (self ) -> Self :
146- if self ._impl :
147- raise AssertionError ("Agent is already in `async with` context" )
148-
149- if self .name :
150- self .logger .debug (f"Creating agent { self .name } ; trace_id={ self .trace_id } " )
151- else :
152- self .logger .debug (f"Creating agent; trace_id={ self .trace_id } " )
153-
154- if self ._use_mcp_tools :
155- self ._tools = await _load_tools_from_mcp (
156- self ._service ,
157- self ._tool_filters ,
158- self .trace_id ,
159- self .logger ,
154+ @asynccontextmanager
155+ async def _start_agent (self ) -> AsyncGenerator [Self ]:
156+ async with AsyncExitStack () as stack :
157+ assert self ._impl is None , (
158+ "internal error: _impl was not set to None after agent invocation"
160159 )
161160
162- backend = get_backend ()
163- self ._impl = await backend .create_agent (self )
161+ if self .name :
162+ self .logger .debug (
163+ f"Creating agent { self .name } ; trace_id={ self .trace_id } "
164+ )
165+ else :
166+ self .logger .debug (f"Creating agent; trace_id={ self .trace_id } " )
167+
168+ if self ._use_mcp_tools :
169+ tools : list [Tool ] = []
170+
171+ self .logger .debug ("Local tool registry detected" )
172+ local_tools_path , app_id = _local_tools_path ()
173+ if local_tools_path :
174+ local_session = await stack .enter_async_context (
175+ connect_local_mcp (local_tools_path , self .logger )
176+ )
177+ self .logger .debug ("Loading local tools" )
178+ local_tools = await load_mcp_tools (
179+ local_session , "local" , app_id , self .trace_id , self ._service
180+ )
181+ self .logger .debug (f"Local tools loaded; { local_tools = } " )
182+ tools .extend (local_tools )
183+
184+ self .logger .debug ("Probing MCP Server App availability" )
185+ remote_session = await stack .enter_async_context (
186+ connect_remote_mcp (
187+ self ._service ,
188+ app_id ,
189+ self .trace_id ,
190+ )
191+ )
192+ if remote_session :
193+ self .logger .debug ("Loading remote tools - MCP Server available" )
194+ remote_tools = await load_mcp_tools (
195+ remote_session ,
196+ "remote" ,
197+ app_id ,
198+ self .trace_id ,
199+ self ._service ,
200+ )
201+ self .logger .debug (f"Remote tools loaded; { remote_tools = } " )
202+ tools .extend (remote_tools )
203+
204+ if self ._tool_filters :
205+ tools = filter_tools (tools , self ._tool_filters )
206+
207+ self .logger .debug (
208+ f"Tools loaded & filtered successfully; tools_after_filtering={ [tool .name for tool in tools ]} "
209+ )
210+
211+ self ._tools = tools
212+
213+ backend = get_backend ()
214+ self ._impl = await backend .create_agent (self )
215+
216+ if self .name :
217+ self .logger .debug (
218+ f"Agent { self .name } created; trace_id={ self .trace_id } "
219+ )
220+ else :
221+ self .logger .debug (f"Agent created; trace_id={ self .trace_id } " )
222+
223+ yield self
224+
225+ self ._impl = None
164226
165- if self .name :
166- self .logger .debug (f"Agent { self .name } created; trace_id={ self .trace_id } " )
167- else :
168- self .logger .debug (f"Agent created; trace_id={ self .trace_id } " )
169-
170- return self
171-
172- async def __aexit__ (self , exc_type , exc_value , traceback ) -> None : # noqa: ANN001 # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
173- self ._impl = None # Make sure invoke fails if called after exit.
174- return None
227+ async def __aenter__ (self ) -> Self :
228+ if self ._agent_context_manager :
229+ raise AssertionError ("Agent is already in `async with` context" )
230+ self ._agent_context_manager = self ._start_agent ()
231+ return await self ._agent_context_manager .__aenter__ ()
232+
233+ async def __aexit__ (
234+ self , exc_type : ..., exc_value : ..., traceback : ...
235+ ) -> bool | None :
236+ assert self ._agent_context_manager is not None
237+ return await self ._agent_context_manager .__aexit__ (
238+ exc_type ,
239+ exc_value ,
240+ traceback ,
241+ )
175242
176243 @override
177244 async def invoke (self , messages : list [BaseMessage ]) -> AgentResponse [OutputT ]:
@@ -181,12 +248,7 @@ async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]:
181248 return await self ._impl .invoke (messages )
182249
183250
184- async def _load_tools_from_mcp (
185- service : Service ,
186- filters : ToolFilters | None ,
187- trace_id : str ,
188- logger : Logger ,
189- ) -> list [Tool ]:
251+ def _local_tools_path () -> tuple [str | None , str ]:
190252 local_tools_path = _testing_local_tools_path
191253 app_id = _testing_app_id
192254
@@ -201,14 +263,4 @@ async def _load_tools_from_mcp(
201263 if not os .path .exists (local_tools_path ):
202264 local_tools_path = None
203265
204- mcp_tools = await load_mcp_tools (
205- service , local_tools_path , app_id , trace_id , logger
206- )
207- if filters :
208- return filter_tools (mcp_tools , filters )
209-
210- logger .debug (
211- f"Tools loaded & filtered successfully; tools_after_filtering={ [tool .name for tool in mcp_tools ]} "
212- )
213-
214- return mcp_tools
266+ return local_tools_path , app_id
0 commit comments