1313# License for the specific language governing permissions and limitations
1414# under the License.
1515
16+ import logging
1617import uuid
1718from collections .abc import Sequence
1819from dataclasses import dataclass
1920from functools import partial
2021from time import monotonic
21- from typing import Any , cast , override
22+ from typing import Any , Awaitable , Callable , cast , override
2223
2324from langchain .agents import create_agent
2425from langchain .agents .middleware import (
2526 AgentMiddleware as LC_AgentMiddleware ,
27+ wrap_tool_call ,
2628)
2729from langchain .agents .middleware import (
2830 AgentState as LC_AgentState ,
4042from langchain .messages import ToolCall as LC_ToolCall
4143from langchain .messages import ToolMessage as LC_ToolMessage
4244from langchain .tools import ToolException as LC_ToolException
45+ from langchain .tools .tool_node import ToolCallRequest as LC_ToolCallRequest
4346from langchain_core .language_models import BaseChatModel
4447from langchain_core .messages .base import BaseMessage as LC_BaseMessage
4548from langchain_core .messages .utils import count_tokens_approximately
4649from langchain_core .tools import BaseTool , StructuredTool
4750from langgraph .checkpoint .memory import InMemorySaver
4851from langgraph .graph .state import CompiledStateGraph , RunnableConfig
4952from langgraph .runtime import Runtime
53+ from langgraph .types import Command as LC_Command
5054
5155from splunklib .ai .base_agent import BaseAgent
5256from splunklib .ai .core .backend import (
5862from splunklib .ai .hooks import (
5963 AgentHook ,
6064 AgentState ,
61- StepsLimitExceededException ,
62- TimeoutExceededException ,
63- TokenLimitExceededException ,
65+ after_model as hook_after_model ,
66+ before_model as hook_before_model ,
6467)
6568from splunklib .ai .messages import (
6669 AgentCall ,
@@ -192,12 +195,28 @@ async def create_agent(
192195
193196 system_prompt = AGENT_AS_TOOLS_PROMPT + "\n " + system_prompt
194197
195- middleware = []
198+ before_user_hooks , after_user_hooks , before_user_lc_middlewares = (
199+ _debugging_middleware (agent .logger )
200+ )
201+
202+ middleware = [
203+ _convert_hook_to_middleware (h , model_impl ) for h in before_user_hooks
204+ ]
205+ middleware .extend (before_user_lc_middlewares )
206+
207+ # User-provided hooks go in between our hooks.
196208 if agent .hooks :
197209 middleware .extend (
198- (_convert_hook_to_middleware (h , model_impl ) for h in agent .hooks )
210+ (
211+ _convert_hook_to_middleware (h , model_impl , logger = agent .logger )
212+ for h in agent .hooks
213+ )
199214 )
200215
216+ middleware .extend (
217+ (_convert_hook_to_middleware (h , model_impl ) for h in after_user_hooks )
218+ )
219+
201220 return LangChainAgentImpl (
202221 system_prompt = system_prompt ,
203222 model = model_impl ,
@@ -207,6 +226,73 @@ async def create_agent(
207226 )
208227
209228
229+ def _debugging_middleware (
230+ logger : logging .Logger ,
231+ ) -> tuple [list [AgentHook ], list [AgentHook ], list [LC_AgentMiddleware ]]:
232+ # TODO: These names can conflict with user-provided names.
233+
234+ # TODO: replace this with ours middleware, once we add them.
235+ @wrap_tool_call # pyright: ignore[reportArgumentType, reportCallIssue, reportUntypedFunctionDecorator]
236+ async def _tool_call (
237+ request : LC_ToolCallRequest ,
238+ handler : Callable [
239+ [LC_ToolCallRequest ], Awaitable [LC_ToolMessage | LC_Command [None ]]
240+ ],
241+ ) -> LC_ToolMessage | LC_Command [None ]:
242+ call = _map_tool_call_from_langchain (request .tool_call )
243+
244+ tool_or_agent = "Tool"
245+ if isinstance (call , AgentCall ):
246+ tool_or_agent = "Agent"
247+
248+ logger .debug (f"{ tool_or_agent } call { call .name } stared; id={ call .id } " )
249+ try :
250+ result = await handler (request )
251+ assert isinstance (result , LC_ToolMessage )
252+
253+ if result .status == "success" :
254+ logger .debug (
255+ f"{ tool_or_agent } call { call .name } succeeded; id={ call .id } "
256+ )
257+ else :
258+ logger .debug (f"{ tool_or_agent } call { call .name } failed; id={ call .id } " )
259+
260+ return result
261+ except Exception :
262+ logger .debug (f"{ tool_or_agent } call { call .name } failed; id={ call .id } " )
263+ raise
264+
265+ before_user_lc_middlewares = [_tool_call ]
266+
267+ @hook_after_model
268+ def _debug_after_model (state : AgentState ) -> None :
269+ last = state .response .messages [- 1 ]
270+ if isinstance (last , AIMessage ):
271+ tool_calls = [
272+ (call .name , call .id )
273+ for call in last .calls
274+ if isinstance (call , ToolCall )
275+ ]
276+ subagent_calls = [
277+ (call .name , call .id )
278+ for call in last .calls
279+ if isinstance (call , AgentCall )
280+ ]
281+ logger .debug (
282+ f"LLM model invocation ended; requested_tool_calls={ tool_calls } ; requested_subagent_calls={ subagent_calls } "
283+ )
284+
285+ before_user_hooks = [_debug_after_model ]
286+
287+ @hook_before_model
288+ def _debug_before_model (state : AgentState ) -> None :
289+ logger .debug ("Invoking LLM model" )
290+
291+ after_user_hooks = [_debug_before_model ]
292+
293+ return before_user_hooks , after_user_hooks , before_user_lc_middlewares # pyright: ignore[reportReturnType]
294+
295+
210296def _create_langchain_tool (tool : Tool ) -> BaseTool :
211297 async def _tool_call (
212298 ** kwargs : dict [str , Any ],
@@ -389,6 +475,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_BaseMessage:
389475def _convert_hook_to_middleware (
390476 hook : AgentHook ,
391477 model : BaseChatModel ,
478+ logger : logging .Logger | None = None ,
392479) -> LC_AgentMiddleware :
393480 match hook .type :
394481 case "before_model" :
@@ -414,6 +501,10 @@ def _middleware(state: LC_AgentState, runtime: Runtime) -> dict[str, Any] | None
414501 # the token counting function as part of the Backend interface, so that
415502 # it's only used when needed instead.
416503 sdk_state = _convert_agent_state_from_langchain (state , model )
504+
505+ if logger :
506+ logger .debug (f"Executing { hook .type } hook { hook .name } " )
507+
417508 hook (sdk_state )
418509
419510 return wrapper (_middleware )
0 commit comments