Skip to content

Commit b4d2c66

Browse files
authored
Support async hooks (#65)
1 parent 7e1ec18 commit b4d2c66

3 files changed

Lines changed: 79 additions & 17 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15+
from inspect import isawaitable
1516
import logging
1617
import uuid
1718
from collections.abc import Sequence
@@ -500,7 +501,9 @@ def _convert_hook_to_middleware(
500501
case _:
501502
raise AssertionError(f"Unsupported middleware type: {hook.type}")
502503

503-
def _middleware(state: LC_AgentState, runtime: Runtime) -> dict[str, Any] | None:
504+
async def _middleware(
505+
state: LC_AgentState, runtime: Runtime
506+
) -> dict[str, Any] | None:
504507
# NOTE: We're converting the langchain AgentState into the SDK AgentState
505508
# on each middleware call.
506509
# We're converting all the messages back to the SDK format and counting the
@@ -516,7 +519,10 @@ def _middleware(state: LC_AgentState, runtime: Runtime) -> dict[str, Any] | None
516519
if logger:
517520
logger.debug(f"Executing {hook.type} hook {hook.name}")
518521

519-
hook(sdk_state)
522+
res = hook(sdk_state)
523+
if isawaitable(res):
524+
await res
525+
return None
520526

521527
return wrapper(_middleware)
522528

splunklib/ai/hooks.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
from time import monotonic
3-
from typing import Any, Callable, Literal, Protocol, final, override
3+
from typing import Any, Awaitable, Callable, Literal, Protocol, final, override
44

55
from splunklib.ai.messages import AgentResponse
66

@@ -34,7 +34,7 @@ class AgentHook(Protocol):
3434
# Name of the middleware must be unique
3535
name: str
3636

37-
def __call__(self, state: AgentState) -> None:
37+
def __call__(self, state: AgentState) -> None | Awaitable[None]:
3838
"""Called at specific points during the agent execution, depending on the hook type."""
3939

4040

@@ -65,7 +65,7 @@ def __init__(self, timeout_seconds: float) -> None:
6565

6666
def _create_hook(
6767
type: HookType,
68-
func: Callable[[AgentState], None],
68+
func: Callable[[AgentState], None | Awaitable[None]],
6969
name: str | None = None,
7070
) -> AgentHook:
7171
mw_name = name or func.__name__
@@ -77,31 +77,31 @@ class CustomHook(AgentHook):
7777
name = mw_name
7878

7979
@override
80-
def __call__(self, state: AgentState) -> None:
80+
def __call__(self, state: AgentState) -> None | Awaitable[None]:
8181
return func(state)
8282

8383
return CustomHook()
8484

8585

86-
def before_model(func: Callable[[AgentState], None]) -> AgentHook:
86+
def before_model(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
8787
"""This hook is called before each model call."""
8888

8989
return _create_hook("before_model", func)
9090

9191

92-
def after_model(func: Callable[[AgentState], None]) -> AgentHook:
92+
def after_model(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
9393
"""This hook is called after each model call."""
9494

9595
return _create_hook("after_model", func)
9696

9797

98-
def before_agent(func: Callable[[AgentState], None]) -> AgentHook:
98+
def before_agent(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
9999
"""This hook is called once per agent invocation. Before any model calls."""
100100

101101
return _create_hook("before_agent", func)
102102

103103

104-
def after_agent(func: Callable[[AgentState], None]) -> AgentHook:
104+
def after_agent(func: Callable[[AgentState], None | Awaitable[None]]) -> AgentHook:
105105
"""This hook is called once per agent invocation. After all model calls."""
106106

107107
return _create_hook("after_agent", func)

tests/integration/ai/test_hooks.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,35 @@ async def test_agent_hooks_duplicated(self):
5757
async def test_agent_hook(self):
5858
pytest.importorskip("langchain_openai")
5959

60+
hook_calls = 0
61+
6062
@final
6163
class TestHook(AgentHook):
6264
type = "before_model"
63-
name = "test_hook"
65+
name = "test_async_hook"
6466

6567
@override
6668
def __call__(self, state: AgentState) -> None:
69+
nonlocal hook_calls
70+
hook_calls += 1
71+
assert len(state.response.messages) == 1
72+
73+
@final
74+
class TestAsyncHook(AgentHook):
75+
type = "before_model"
76+
name = "test_hook"
77+
78+
@override
79+
async def __call__(self, state: AgentState) -> None:
80+
nonlocal hook_calls
81+
hook_calls += 1
6782
assert len(state.response.messages) == 1
6883

6984
async with Agent(
7085
model=(await self.model()),
7186
system_prompt="Your name is stefan",
7287
service=self.service,
73-
hooks=[TestHook()],
88+
hooks=[TestHook(), TestAsyncHook()],
7489
) as agent:
7590
result = await agent.invoke(
7691
[
@@ -82,6 +97,7 @@ def __call__(self, state: AgentState) -> None:
8297

8398
response = result.messages[-1].content.strip().lower().replace(".", "")
8499
assert "stefan" == response
100+
assert hook_calls == 2
85101

86102
@pytest.mark.asyncio
87103
async def test_agent_hook_decorator(self):
@@ -96,18 +112,37 @@ def test_hook_before(state: AgentState) -> None:
96112

97113
assert len(state.response.messages) == 1
98114

115+
@before_model
116+
async def test_async_hook_before(state: AgentState) -> None:
117+
nonlocal hook_calls
118+
hook_calls += 1
119+
120+
assert len(state.response.messages) == 1
121+
99122
@after_model
100123
def test_hook_after(state: AgentState) -> None:
101124
nonlocal hook_calls
102125
hook_calls += 1
103126

104127
assert len(state.response.messages) == 2
105128

129+
@after_model
130+
async def test_async_hook_after(state: AgentState) -> None:
131+
nonlocal hook_calls
132+
hook_calls += 1
133+
134+
assert len(state.response.messages) == 2
135+
106136
async with Agent(
107137
model=(await self.model()),
108138
system_prompt="Your name is stefan",
109139
service=self.service,
110-
hooks=[test_hook_before, test_hook_after],
140+
hooks=[
141+
test_hook_before,
142+
test_async_hook_before,
143+
test_hook_after,
144+
test_async_hook_after,
145+
],
111146
) as agent:
112147
result = await agent.invoke(
113148
[
@@ -119,7 +154,7 @@ def test_hook_after(state: AgentState) -> None:
119154

120155
response = result.messages[-1].content.strip().lower().replace(".", "")
121156
assert "stefan" == response
122-
assert hook_calls == 2
157+
assert hook_calls == 4
123158

124159
@pytest.mark.asyncio
125160
async def test_agent_hook_agent(self):
@@ -137,8 +172,24 @@ def before_agent_hook(state: AgentState) -> None:
137172

138173
assert len(state.response.messages) == 1
139174

175+
@before_agent
176+
async def before_async_agent_hook(state: AgentState) -> None:
177+
nonlocal hook_calls
178+
hook_calls += 1
179+
180+
assert len(state.response.messages) == 1
181+
140182
@after_agent
141-
def after_agent_hook(state: AgentState) -> None:
183+
async def after_agent_hook(state: AgentState) -> None:
184+
nonlocal hook_calls
185+
hook_calls += 1
186+
187+
person = state.response.structured_output
188+
assert person.name.lower() == "stefan"
189+
assert len(state.response.messages) == 2
190+
191+
@after_agent
192+
async def after_async_agent_hook(state: AgentState) -> None:
142193
nonlocal hook_calls
143194
hook_calls += 1
144195

@@ -150,7 +201,12 @@ def after_agent_hook(state: AgentState) -> None:
150201
model=(await self.model()),
151202
system_prompt="Your name is stefan",
152203
service=self.service,
153-
hooks=[before_agent_hook, after_agent_hook],
204+
hooks=[
205+
before_agent_hook,
206+
before_async_agent_hook,
207+
after_agent_hook,
208+
after_async_agent_hook,
209+
],
154210
output_schema=Person,
155211
) as agent:
156212
result = await agent.invoke(
@@ -163,7 +219,7 @@ def after_agent_hook(state: AgentState) -> None:
163219

164220
response = result.messages[-1].content.strip().lower().replace(".", "")
165221
assert '{"name":"stefan"}' == response
166-
assert hook_calls == 2
222+
assert hook_calls == 4
167223

168224
@pytest.mark.asyncio
169225
async def test_agent_loop_stop_conditions_token_limit(self):

0 commit comments

Comments
 (0)