Skip to content

Commit 1577899

Browse files
authored
Add hooks (#45)
Hooks help with inserting callbacks into different moments of the Agent Loop execution. This can be used to log/audit/debug and to stop the Agentic Loop early on.
1 parent 0d0bbc2 commit 1577899

8 files changed

Lines changed: 592 additions & 235 deletions

File tree

splunklib/ai/README.md

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,105 @@ async with Agent(
365365

366366
**Note**: Currently input schemas can only be used by subagents, not by regular agents.
367367

368-
## Loop Stop Conditions
368+
## Hooks
369+
370+
Hooks are user-defined callback functions that can be registered to execute at specific points
371+
during the agent's operation. Hooks allow developers to add custom behavior, logging and monitoring
372+
or implement custom stopping conditions for the agent loop without modifying the core agent logic.
373+
374+
There are several types of hooks available.
375+
They differ by the point in the execution flow where they are invoked:
376+
377+
- before_model: before each model call
378+
- after_model: after each model call
379+
- before_agent: once per agent invocation, before any model calls
380+
- after_agent: once per agent invocation, after all model calls
381+
382+
Example hook that logs token usage after each model call:
383+
384+
```py
385+
from splunklib.ai import Agent, OpenAIModel
386+
from splunklib.ai.hooks import after_model
387+
from splunklib.client import connect
388+
389+
import logging
390+
391+
logger = logging.getLogger(__name__)
392+
393+
model = OpenAIModel(...)
394+
service = connect(...)
395+
396+
@after_model
397+
def log_token_usage(state: AgentState) -> None:
398+
logger.debug(f"Model used {state.token_count} tokens up to this point")
399+
400+
401+
async with Agent(
402+
model=model,
403+
service=service,
404+
system_prompt="..." ,
405+
hooks=[log_token_usage],
406+
) as agent: ...
407+
```
408+
409+
The same hook can be defined as a class. It needs to provide the type and name attributes, and implement the `__call__` method:
410+
411+
```py
412+
from typing import final, override
413+
from splunklib.ai.hooks import AgentHook, AgentState
414+
import logging
415+
416+
logger = logging.getLogger(__name__)
417+
418+
@final
419+
class LoggingHook(AgentHook):
420+
type = "before_model"
421+
name = "test_hook"
422+
423+
@override
424+
def __call__(self, state: AgentState) -> None:
425+
logger.debug(f"Model used {state.token_count} tokens up to this point")
426+
427+
async with Agent(
428+
model=model,
429+
service=service,
430+
system_prompt="..." ,
431+
hooks=[LoggingHook()],
432+
) as agent: ...
433+
```
434+
435+
The hooks can stop the Agentic Loop under custom conditions by raising exceptions.
436+
The logic of the hook can be more advanced and include multiple conditions, for example, based on both token usage and execution time:
437+
438+
```py
439+
from splunklib.ai import Agent, OpenAIModel
440+
from splunklib.ai.hooks import before_model, AgentHook
441+
from time import monotonic
442+
443+
def timeout_or_token_limit(seconds_limit: float, token_limit: float) -> AgentHook:
444+
now = monotonic()
445+
timeout = now + seconds_limit
446+
447+
@before_model
448+
def _limit_hook(state: AgentState) -> None:
449+
if state.token_count > token_limit or monotonic() >= timeout:
450+
raise Exception("Stopping Agentic Loop")
451+
452+
return _limit_hook
453+
454+
455+
async with Agent(
456+
...,
457+
hooks=[timeout_or_token_limit(seconds_limit=10.0, token_limit=10000)],
458+
) as agent: ...
459+
```
460+
461+
### Predefined hooks for loop stopping conditions
369462

370463
To prevent excessive token usage or runaway execution, an Agent can be constrained
371-
using loop stop conditions.
464+
using predefined hooks.
372465

373-
Stop conditions allow you to automatically terminate the agent loop when one or more
466+
Those hooks allow you to automatically terminate the agent loop when one or more
374467
limits are reached, such as:
375468

376469
- Maximum number of generated tokens
@@ -379,7 +472,7 @@ limits are reached, such as:
379472

380473
```py
381474
from splunklib.ai import Agent, OpenAIModel
382-
from splunklib.ai.stop_conditions import StopConditions
475+
from splunklib.ai.hooks import token_limit, step_limit, timeout_limit
383476
from splunklib.client import connect
384477

385478
model = OpenAIModel(...)
@@ -389,11 +482,11 @@ async with Agent(
389482
model=model,
390483
service=service,
391484
system_prompt="..." ,
392-
loop_stop_conditions=StopConditions(
393-
token_limit = 10000,
394-
steps_limit = 25,
395-
timeout_seconds = 10.5,
396-
),
485+
hooks=[
486+
token_limit(10000),
487+
step_limit(25),
488+
timeout_limit(10.5),
489+
],
397490
) as agent: ...
398491
```
399492

splunklib/ai/agent.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from splunklib.ai.base_agent import BaseAgent
2323
from splunklib.ai.core.backend import AgentImpl
2424
from splunklib.ai.core.backend_registry import get_backend
25+
from splunklib.ai.hooks import AgentHook
2526
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2627
from splunklib.ai.model import PredefinedModel
27-
from splunklib.ai.stop_conditions import StopConditions
2828
from splunklib.ai.tool_filtering import ToolFilters, filter_tools
2929
from splunklib.ai.tools import (
3030
Tool,
@@ -88,11 +88,10 @@ class Agent(BaseAgent[OutputT]):
8888
used as a *subagent*. The supervisor agent uses this schema to
8989
understand how to call the subagent and how to format its inputs.
9090
91-
loop_stop_conditions:
92-
Optional `StopConditions` instance defining automatic termination.
93-
If any limit is exceeded, the corresponding exception
94-
(`TokenLimitExceededException`, `StepsLimitExceededException`,
95-
or `TimeoutExceededException`) is raised.
91+
hooks:
92+
Optional sequence of `AgentHook`. Hooks are user-defined callback
93+
functions that can be registered to execute at specific points
94+
during the agent's operation.
9695
9796
name:
9897
Name of the agent when used as a subagent. This is
@@ -122,7 +121,7 @@ def __init__(
122121
agents: Sequence[BaseAgent[BaseModel | None]] | None = None,
123122
output_schema: type[OutputT] | None = None,
124123
input_schema: type[BaseModel] | None = None, # Only used by Subgents
125-
loop_stop_conditions: StopConditions | None = None,
124+
hooks: Sequence[AgentHook] | None = None,
126125
name: str = "", # Only used by Subgents
127126
description: str = "", # Only used by Subagents
128127
) -> None:
@@ -134,9 +133,12 @@ def __init__(
134133
agents=agents,
135134
input_schema=input_schema,
136135
output_schema=output_schema,
137-
loop_stop_conditions=loop_stop_conditions,
136+
hooks=hooks,
138137
)
139138

139+
if duplicate_hook_names := _find_duplicate_hook_names(self.hooks):
140+
raise ValueError(f"Duplicate hook names found: {duplicate_hook_names!r}")
141+
140142
self._use_mcp_tools = use_mcp_tools
141143
self._tool_filters = tool_filters
142144
self._service = service
@@ -181,3 +183,19 @@ async def _load_tools_from_mcp(
181183
return filter_tools(mcp_tools, filters)
182184

183185
return mcp_tools
186+
187+
188+
def _find_duplicate_hook_names(hooks: Sequence[AgentHook] | None) -> set[str]:
189+
seen: set[str] = set()
190+
duplicates: set[str] = set()
191+
192+
if not hooks:
193+
return set()
194+
195+
for hook in hooks:
196+
if hook.name in seen:
197+
duplicates.add(hook.name)
198+
else:
199+
seen.add(hook.name)
200+
201+
return duplicates

splunklib/ai/base_agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
from pydantic import BaseModel
2121

22+
from splunklib.ai.hooks import AgentHook
2223
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2324
from splunklib.ai.model import PredefinedModel
24-
from splunklib.ai.stop_conditions import StopConditions
2525
from splunklib.ai.tools import Tool
2626

2727

@@ -34,7 +34,7 @@ class BaseAgent(Generic[OutputT], ABC):
3434
_description: str = ""
3535
_input_schema: type[BaseModel] | None = None
3636
_output_schema: type[OutputT] | None = None
37-
_loop_stop_conditions: StopConditions | None = None
37+
_hooks: Sequence[AgentHook] | None = None
3838

3939
def __init__(
4040
self,
@@ -46,7 +46,7 @@ def __init__(
4646
agents: Sequence["BaseAgent[BaseModel | None]"] | None = None,
4747
input_schema: type[BaseModel] | None = None,
4848
output_schema: type[OutputT] | None = None,
49-
loop_stop_conditions: StopConditions | None = None,
49+
hooks: Sequence[AgentHook] | None = None,
5050
) -> None:
5151
self._system_prompt = system_prompt
5252
self._model = model
@@ -56,7 +56,7 @@ def __init__(
5656
self._agents = tuple(agents) if agents else ()
5757
self._input_schema = input_schema
5858
self._output_schema = output_schema
59-
self._loop_stop_conditions = loop_stop_conditions
59+
self._hooks = tuple(hooks) if hooks else ()
6060

6161
@abstractmethod
6262
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]: ...
@@ -94,5 +94,5 @@ def output_schema(self) -> type[OutputT] | None:
9494
return self._output_schema
9595

9696
@property
97-
def loop_stop_conditions(self) -> StopConditions | None:
98-
return self._loop_stop_conditions
97+
def hooks(self) -> Sequence[AgentHook] | None:
98+
return self._hooks

0 commit comments

Comments
 (0)