Skip to content

Commit f63d44f

Browse files
authored
Add agent middleware (#78)
1 parent 2236203 commit f63d44f

5 files changed

Lines changed: 496 additions & 23 deletions

File tree

splunklib/ai/README.md

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,16 +393,18 @@ Each middleware can inspect input, call `handler(request)`, and modify the retur
393393

394394
Available decorators:
395395

396+
- `agent_middleware`
396397
- `model_middleware`
397398
- `tool_middleware`
398399
- `subagent_middleware`
399400

400401
Class-based middleware:
401402

402403
```py
403-
from typing import override
404+
from typing import Any, override
404405
from splunklib.ai.middleware import (
405-
AgentMiddleware,
406+
AgentMiddlewareHandler,
407+
AgentRequest,
406408
ModelMiddlewareHandler,
407409
ModelRequest,
408410
SubagentMiddlewareHandler,
@@ -412,10 +414,20 @@ from splunklib.ai.middleware import (
412414
ToolRequest,
413415
ToolResponse,
414416
)
415-
from splunklib.ai.messages import AIMessage
417+
from splunklib.ai.messages import AIMessage, AgentResponse, ToolCall
416418

417419

418420
class ExampleMiddleware(AgentMiddleware):
421+
@override
422+
async def agent_middleware(
423+
self, request: AgentRequest, handler: AgentMiddlewareHandler
424+
) -> AgentResponse[Any | None]:
425+
# Keep retrying until the agent makes at least one tool call.
426+
resp = await handler(request)
427+
while not any(m for m in resp.messages if isinstance(m, ToolCall)):
428+
resp = await handler(request)
429+
return resp
430+
419431
@override
420432
async def model_middleware(
421433
self, request: ModelRequest, handler: ModelMiddlewareHandler
@@ -442,6 +454,29 @@ class ExampleMiddleware(AgentMiddleware):
442454
return await handler(request)
443455
```
444456

457+
Example agent middleware:
458+
459+
```py
460+
from typing import Any
461+
from splunklib.ai.middleware import (
462+
agent_middleware,
463+
AgentMiddlewareHandler,
464+
AgentRequest,
465+
)
466+
from splunklib.ai.messages import AgentResponse, ToolCall
467+
468+
469+
@agent_middleware
470+
async def force_tool_call(
471+
request: AgentRequest, handler: AgentMiddlewareHandler
472+
) -> AgentResponse[Any | None]:
473+
# Keep retrying until the agent makes at least one tool call.
474+
resp = await handler(request)
475+
while not any(m for m in resp.messages if isinstance(m, ToolCall)):
476+
resp = await handler(request)
477+
return resp
478+
```
479+
445480
Example model middleware:
446481

447482
```py

splunklib/ai/engines/langchain.py

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@
7979
ToolMessage,
8080
)
8181
from splunklib.ai.middleware import (
82+
AgentMiddlewareHandler,
8283
AgentState,
8384
AgentMiddleware,
85+
AgentRequest,
8486
ModelMiddlewareHandler,
8587
ModelRequest,
8688
SubagentMiddlewareHandler,
@@ -122,55 +124,128 @@ class LangChainAgentImpl(AgentImpl[OutputT]):
122124
_thread_id: uuid.UUID
123125
_config: RunnableConfig
124126
_output_schema: type[OutputT] | None
127+
_middleware: Sequence[AgentMiddleware]
125128

126129
def __init__(
127130
self,
128131
system_prompt: str,
129132
model: BaseChatModel,
130133
tools: list[BaseTool],
131134
output_schema: type[OutputT] | None,
132-
middleware: Sequence[LC_AgentMiddleware] | None = None,
135+
lcmiddleware: Sequence[LC_AgentMiddleware] | None = None,
136+
middleware: Sequence[AgentMiddleware] | None = None,
133137
) -> None:
134138
super().__init__()
135139
self._output_schema = output_schema
136140
self._thread_id = uuid.uuid4()
137141
self._config = {"configurable": {"thread_id": self._thread_id}}
142+
self._middleware = middleware or []
138143

139144
checkpointer = InMemorySaver()
140-
middleware = middleware or []
141145

142146
self._agent = create_agent(
143147
model=model,
144148
tools=tools,
145149
system_prompt=system_prompt,
146150
checkpointer=checkpointer,
147151
response_format=output_schema,
148-
middleware=middleware,
152+
middleware=lcmiddleware or [],
149153
)
150154

155+
def _with_agent_middleware(
156+
self,
157+
agent_invoke: Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]],
158+
) -> Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]]:
159+
# When provided with a list of middlewares, e.g. [m1, m2, m3],
160+
# they are executed in the following order:
161+
#
162+
# m1 -> m2 -> m3 -> agent_invoke
163+
#
164+
# Each middleware wraps the next one in the chain.
165+
#
166+
# - m1's handler calls m2.agent_middleware(...)
167+
# - m2's handler calls m3.agent_middleware(...)
168+
# - m3's handler eventually calls agent_invoke(...)
169+
#
170+
# We build the chain by iterating in reverse order.
171+
# Each middleware wraps the previously constructed handler,
172+
# so the first middleware in the list becomes the outermost one.
173+
174+
invoke = agent_invoke
175+
for middleware in reversed(self._middleware):
176+
177+
def make_next(
178+
m: AgentMiddleware, h: AgentMiddlewareHandler
179+
) -> AgentMiddlewareHandler:
180+
async def next(r: AgentRequest) -> AgentResponse[Any | None]:
181+
return await m.agent_middleware(r, h)
182+
183+
return next
184+
185+
invoke = make_next(middleware, invoke)
186+
187+
return invoke
188+
151189
@override
152190
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]:
153-
langchain_msgs = [_map_message_to_langchain(m) for m in messages]
191+
async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
192+
langchain_msgs = [_map_message_to_langchain(m) for m in req.messages]
154193

155-
# call the langchain agent
156-
result = await self._agent.ainvoke(
157-
{"messages": langchain_msgs},
158-
config=self._config,
159-
)
194+
# call the langchain agent
195+
result = await self._agent.ainvoke(
196+
{"messages": langchain_msgs},
197+
config=self._config,
198+
)
199+
200+
sdk_msgs = [_map_message_from_langchain(m) for m in result["messages"]]
201+
202+
# NOTE: Agent responses will always conform to output schema. Verifying
203+
# if an LLM made any mistakes or not is _always_ up to the developer.
204+
205+
assert (
206+
self._output_schema is None
207+
or type(result["structured_response"]) is self._output_schema
208+
)
209+
210+
if self._output_schema:
211+
return AgentResponse(
212+
structured_output=result["structured_response"],
213+
messages=sdk_msgs,
214+
)
215+
else:
216+
return AgentResponse(structured_output=None, messages=sdk_msgs)
160217

161-
sdk_msgs = [_map_message_from_langchain(m) for m in result["messages"]]
218+
result = await self._with_agent_middleware(invoke_agent)(
219+
AgentRequest(
220+
messages=messages,
221+
)
222+
)
162223

163-
# NOTE: Agent responses will always conform to output schema. Verifying
164-
# if an LLM made any mistakes or not is _always_ up to the developer.
165224
if self._output_schema:
166-
return AgentResponse(
167-
structured_output=result["structured_response"],
168-
messages=sdk_msgs,
225+
if result.structured_output is None:
226+
raise AssertionError("Agent middleware discarded a structured output")
227+
228+
if type(result.structured_output) is not self._output_schema:
229+
raise AssertionError(
230+
f"Agent middleware returned an invalid structured_output type: {type(result.structured_output)}, want: {self._output_schema}"
231+
)
232+
233+
return AgentResponse[OutputT](
234+
messages=result.messages,
235+
structured_output=result.structured_output,
169236
)
237+
else:
238+
if result.structured_output is not None:
239+
raise AssertionError(
240+
"Agent middleware unexpectedly included a structured output"
241+
)
170242

171-
# HACK: This let's us put None in the structured_output field. It also shows
172-
# None as the field type if no `output_schema`was provided to the Agent class.
173-
return AgentResponse(structured_output=cast(OutputT, None), messages=sdk_msgs)
243+
return AgentResponse[OutputT](
244+
messages=result.messages,
245+
# HACK: This lets us put None in the structured_output field. It also shows
246+
# None as the field type if no `output_schema` was provided to the Agent class.
247+
structured_output=cast(OutputT, None),
248+
)
174249

175250

176251
@final
@@ -229,7 +304,8 @@ async def create_agent(
229304
model=model_impl,
230305
tools=tools,
231306
output_schema=agent.output_schema,
232-
middleware=middleware,
307+
lcmiddleware=middleware,
308+
middleware=agent.middleware,
233309
)
234310

235311

splunklib/ai/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class SubagentCall:
3535
args: dict[str, Any]
3636
id: str | None # TODO: can be None?
3737

38+
3839
@dataclass(frozen=True)
3940
class BaseMessage:
4041
role: str = ""

splunklib/ai/middleware.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from splunklib.ai.messages import (
2020
AIMessage,
2121
AgentResponse,
22+
BaseMessage,
2223
SubagentCall,
2324
ToolCall,
2425
)
@@ -75,6 +76,14 @@ class ModelRequest:
7576
ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[AIMessage]]
7677

7778

79+
@dataclass
80+
class AgentRequest:
81+
messages: list[BaseMessage]
82+
83+
84+
AgentMiddlewareHandler = Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]]
85+
86+
7887
class AgentMiddleware:
7988
async def tool_middleware(
8089
self,
@@ -103,6 +112,15 @@ async def model_middleware(
103112

104113
return await handler(request)
105114

115+
async def agent_middleware(
116+
self,
117+
request: AgentRequest,
118+
handler: AgentMiddlewareHandler,
119+
) -> AgentResponse[Any | None]:
120+
"""Wraps each agent invocation; `handler` runs the next middleware or the agent itself."""
121+
122+
return await handler(request)
123+
106124

107125
def tool_middleware(
108126
func: Callable[[ToolRequest, ToolMiddlewareHandler], Awaitable[ToolResponse]],
@@ -149,3 +167,20 @@ async def model_middleware(
149167
return await func(request, handler)
150168

151169
return _CustomMiddleware()
170+
171+
172+
def agent_middleware(
173+
func: Callable[
174+
[AgentRequest, AgentMiddlewareHandler], Awaitable[AgentResponse[Any | None]]
175+
],
176+
) -> AgentMiddleware:
177+
class _CustomMiddleware(AgentMiddleware):
178+
@override
179+
async def agent_middleware(
180+
self,
181+
request: AgentRequest,
182+
handler: AgentMiddlewareHandler,
183+
) -> AgentResponse[Any | None]:
184+
return await func(request, handler)
185+
186+
return _CustomMiddleware()

0 commit comments

Comments
 (0)