Skip to content

Commit a1eae32

Browse files
authored
Drop name field of hooks (#67)
Instead of solving the collision of names bwetween internal/external hooks lets remove these all together, these exist such that we have something to pass to LC, but other that that these don't serve any usercase (except in debug logs, where we have been printing these names). This change removes the name field of hooks, and while converting to LC we generate a random uuid4 to name these hooks. To not loose the DEBUG logging experience we infer a name for logs for these from the class/function name.
1 parent b4d2c66 commit a1eae32

4 files changed

Lines changed: 44 additions & 69 deletions

File tree

splunklib/ai/agent.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,6 @@ def __init__(
137137
logger=logger,
138138
)
139139

140-
if duplicate_hook_names := _find_duplicate_hook_names(self.hooks):
141-
raise ValueError(f"Duplicate hook names found: {duplicate_hook_names!r}")
142-
143140
self._use_mcp_tools = use_mcp_tools
144141
self._tool_filters = tool_filters
145142
self._service = service
@@ -215,19 +212,3 @@ async def _load_tools_from_mcp(
215212
)
216213

217214
return mcp_tools
218-
219-
220-
def _find_duplicate_hook_names(hooks: Sequence[AgentHook] | None) -> set[str]:
221-
seen: set[str] = set()
222-
duplicates: set[str] = set()
223-
224-
if not hooks:
225-
return set()
226-
227-
for hook in hooks:
228-
if hook.name in seen:
229-
duplicates.add(hook.name)
230-
else:
231-
seen.add(hook.name)
232-
233-
return duplicates

splunklib/ai/engines/langchain.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from splunklib.ai.hooks import (
6363
AgentHook,
6464
AgentState,
65+
FunctionHook,
6566
after_model as hook_after_model,
6667
before_model as hook_before_model,
6768
)
@@ -229,8 +230,6 @@ async def create_agent(
229230
def _debugging_middleware(
230231
logger: logging.Logger,
231232
) -> tuple[list[AgentHook], list[AgentHook], list[LC_AgentMiddleware]]:
232-
# TODO: These names can conflict with user-provided names.
233-
234233
# TODO: replace this with ours middleware, once we add them.
235234
@wrap_tool_call # pyright: ignore[reportArgumentType, reportCallIssue, reportUntypedFunctionDecorator]
236235
async def _tool_call(
@@ -489,15 +488,25 @@ def _convert_hook_to_middleware(
489488
model: BaseChatModel,
490489
logger: logging.Logger | None = None,
491490
) -> LC_AgentMiddleware:
491+
# Inspect the hook to generate a useful name for debug log messages.
492+
hook_name = hook.__class__.__name__
493+
if isinstance(hook, FunctionHook):
494+
hook_name = hook.func.__name__
495+
496+
# Generate a random name to name this hook in langchain.
497+
# We can't use the hook_name, derived above, since it might not be unique, we
498+
# also don't want to force the users to name these hooks, as langchain does.
499+
lc_hook_name = str(uuid.uuid4())
500+
492501
match hook.type:
493502
case "before_model":
494-
wrapper = before_model(can_jump_to=["end"], name=hook.name)
503+
wrapper = before_model(can_jump_to=["end"], name=lc_hook_name)
495504
case "after_model":
496-
wrapper = after_model(can_jump_to=["end"], name=hook.name)
505+
wrapper = after_model(can_jump_to=["end"], name=lc_hook_name)
497506
case "before_agent":
498-
wrapper = before_agent(can_jump_to=["end"], name=hook.name)
507+
wrapper = before_agent(can_jump_to=["end"], name=lc_hook_name)
499508
case "after_agent":
500-
wrapper = after_agent(can_jump_to=["end"], name=hook.name)
509+
wrapper = after_agent(can_jump_to=["end"], name=lc_hook_name)
501510
case _:
502511
raise AssertionError(f"Unsupported middleware type: {hook.type}")
503512

@@ -517,7 +526,7 @@ async def _middleware(
517526
sdk_state = _convert_agent_state_from_langchain(state, model)
518527

519528
if logger:
520-
logger.debug(f"Executing {hook.type} hook {hook.name}")
529+
logger.debug(f"Executing {hook.type} hook {hook_name}")
521530

522531
res = hook(sdk_state)
523532
if isawaitable(res):

splunklib/ai/hooks.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ class AgentHook(Protocol):
3131
"""
3232

3333
type: HookType
34-
# Name of the middleware must be unique
35-
name: str
3634

3735
def __call__(self, state: AgentState) -> None | Awaitable[None]:
3836
"""Called at specific points during the agent execution, depending on the hook type."""
@@ -63,48 +61,54 @@ def __init__(self, timeout_seconds: float) -> None:
6361
super().__init__(f"Timed out after {timeout_seconds} seconds.")
6462

6563

66-
def _create_hook(
67-
type: HookType,
68-
func: Callable[[AgentState], None | Awaitable[None]],
69-
name: str | None = None,
70-
) -> AgentHook:
71-
mw_name = name or func.__name__
72-
mw_type = type
64+
@final
65+
class FunctionHook(AgentHook):
66+
"""
67+
Implementation of AgentHook that wraps a single callable function.
68+
69+
FunctionHook allows creation of a hook from a plain function instead of
70+
defining a full AgentHook subclass.
71+
72+
Use helper decorators: before_model, after_model, before_agent, after_agent to
73+
construct such hook.
74+
"""
7375

74-
@final
75-
class CustomHook(AgentHook):
76-
type = mw_type
77-
name = mw_name
76+
type: HookType
77+
func: Callable[[AgentState], None | Awaitable[None]]
7878

79-
@override
80-
def __call__(self, state: AgentState) -> None | Awaitable[None]:
81-
return func(state)
79+
def __init__(
80+
self, hookType: HookType, func: Callable[[AgentState], None | Awaitable[None]]
81+
) -> None:
82+
self.type = hookType
83+
self.func = func
8284

83-
return CustomHook()
85+
@override
86+
def __call__(self, state: AgentState) -> None | Awaitable[None]:
87+
return self.func(state)
8488

8589

8690
def before_model(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
8791
"""This hook is called before each model call."""
8892

89-
return _create_hook("before_model", func)
93+
return FunctionHook("before_model", func)
9094

9195

9296
def after_model(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
9397
"""This hook is called after each model call."""
9498

95-
return _create_hook("after_model", func)
99+
return FunctionHook("after_model", func)
96100

97101

98102
def before_agent(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
99103
"""This hook is called once per agent invocation. Before any model calls."""
100104

101-
return _create_hook("before_agent", func)
105+
return FunctionHook("before_agent", func)
102106

103107

104108
def after_agent(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
105109
"""This hook is called once per agent invocation. After all model calls."""
106110

107-
return _create_hook("after_agent", func)
111+
return FunctionHook("after_agent", func)
108112

109113

110114
def token_limit(limit: float) -> AgentHook:
@@ -114,7 +118,7 @@ def _token_limit_hook(state: AgentState) -> None:
114118
if state.token_count > limit:
115119
raise TokenLimitExceededException(token_limit=limit)
116120

117-
return _create_hook("before_model", _token_limit_hook, name="builtin_token_limit")
121+
return FunctionHook("before_model", _token_limit_hook)
118122

119123

120124
def step_limit(limit: int) -> AgentHook:
@@ -124,7 +128,7 @@ def _step_limit_hook(state: AgentState) -> None:
124128
if state.total_steps >= limit:
125129
raise StepsLimitExceededException(steps_limit=limit)
126130

127-
return _create_hook("before_model", _step_limit_hook, name="builtin_step_limit")
131+
return FunctionHook("before_model", _step_limit_hook)
128132

129133

130134
def timeout_limit(seconds: float) -> AgentHook:
@@ -137,6 +141,4 @@ def _timeout_limit_hook(_state: AgentState) -> None:
137141
if monotonic() >= timeout:
138142
raise TimeoutExceededException(timeout_seconds=seconds)
139143

140-
return _create_hook(
141-
"before_model", _timeout_limit_hook, name="builtin_timeout_limit"
142-
)
144+
return FunctionHook("before_model", _timeout_limit_hook)

tests/integration/ai/test_hooks.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,6 @@
3838

3939

4040
class TestHook(AITestCase):
41-
@pytest.mark.asyncio
42-
async def test_agent_hooks_duplicated(self):
43-
pytest.importorskip("langchain_openai")
44-
45-
with pytest.raises(
46-
ValueError, match="Duplicate hook names found: {'builtin_step_limit'}"
47-
):
48-
async with Agent(
49-
model=(await self.model()),
50-
system_prompt="Your name is stefan",
51-
service=self.service,
52-
hooks=[step_limit(5), step_limit(10)],
53-
) as agent:
54-
...
55-
5641
@pytest.mark.asyncio
5742
async def test_agent_hook(self):
5843
pytest.importorskip("langchain_openai")
@@ -62,7 +47,6 @@ async def test_agent_hook(self):
6247
@final
6348
class TestHook(AgentHook):
6449
type = "before_model"
65-
name = "test_async_hook"
6650

6751
@override
6852
def __call__(self, state: AgentState) -> None:
@@ -73,7 +57,6 @@ def __call__(self, state: AgentState) -> None:
7357
@final
7458
class TestAsyncHook(AgentHook):
7559
type = "before_model"
76-
name = "test_hook"
7760

7861
@override
7962
async def __call__(self, state: AgentState) -> None:

0 commit comments

Comments
 (0)