Skip to content

Commit daf8f05

Browse files
authored
Introduce middleware (#69)
1 parent 442e3f9 commit daf8f05

9 files changed

Lines changed: 1364 additions & 54 deletions

File tree

splunklib/ai/README.md

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,163 @@ async with Agent(
386386

387387
**Note**: Currently input schemas can only be used by subagents, not by regular agents.
388388

389+
## Middleware
390+
391+
Middleware lets you intercept model, tool, and subagent calls in a request/handler chain.
392+
Each middleware can inspect input, call `handler(request)`, and modify the returned response.
393+
394+
Available decorators:
395+
396+
- `model_middleware`
397+
- `tool_middleware`
398+
- `subagent_middleware`
399+
400+
Class-based middleware:
401+
402+
```py
403+
from typing import override
404+
from splunklib.ai.middleware import (
405+
AgentMiddleware,
406+
ModelMiddlewareHandler,
407+
ModelRequest,
408+
SubagentMiddlewareHandler,
409+
SubagentRequest,
410+
SubagentResponse,
411+
ToolMiddlewareHandler,
412+
ToolRequest,
413+
ToolResponse,
414+
)
415+
from splunklib.ai.messages import AIMessage
416+
417+
418+
class ExampleMiddleware(AgentMiddleware):
419+
@override
420+
async def model_middleware(
421+
self, request: ModelRequest, handler: ModelMiddlewareHandler
422+
) -> AIMessage:
423+
request.system_message = request.system_message.replace("SECRET", "[REDACTED]")
424+
return await handler(request)
425+
426+
@override
427+
async def tool_middleware(
428+
self, request: ToolRequest, handler: ToolMiddlewareHandler
429+
) -> ToolResponse:
430+
if request.call.name == "temperature":
431+
return ToolResponse(content="25.0")
432+
return await handler(request)
433+
434+
@override
435+
async def subagent_middleware(
436+
self, request: SubagentRequest, handler: SubagentMiddlewareHandler
437+
) -> SubagentResponse:
438+
if request.call.name == "SummaryAgent":
439+
return SubagentResponse(
440+
content="Executive summary: no critical incidents detected."
441+
)
442+
return await handler(request)
443+
```
444+
445+
Example model middleware:
446+
447+
```py
448+
from splunklib.ai.middleware import (
449+
model_middleware,
450+
ModelMiddlewareHandler,
451+
ModelRequest,
452+
)
453+
from splunklib.ai.messages import AIMessage
454+
455+
456+
@model_middleware
457+
async def redact_system_prompt(
458+
request: ModelRequest, handler: ModelMiddlewareHandler
459+
) -> AIMessage:
460+
request.system_message = request.system_message.replace("SECRET", "[REDACTED]")
461+
return await handler(request)
462+
```
463+
464+
Example tool middleware:
465+
466+
```py
467+
from splunklib.ai.middleware import (
468+
tool_middleware,
469+
ToolMiddlewareHandler,
470+
ToolRequest,
471+
ToolResponse,
472+
)
473+
474+
475+
@tool_middleware
476+
async def mock_temperature(
477+
request: ToolRequest, handler: ToolMiddlewareHandler
478+
) -> ToolResponse:
479+
if request.call.name == "temperature":
480+
return ToolResponse(content="25.0")
481+
return await handler(request)
482+
```
483+
484+
Example subagent middleware:
485+
486+
```py
487+
from splunklib.ai.middleware import (
488+
subagent_middleware,
489+
SubagentMiddlewareHandler,
490+
SubagentRequest,
491+
SubagentResponse,
492+
)
493+
494+
495+
@subagent_middleware
496+
async def mock_subagent(
497+
request: SubagentRequest, handler: SubagentMiddlewareHandler
498+
) -> SubagentResponse:
499+
if request.call.name == "SummaryAgent":
500+
return SubagentResponse(
501+
content="Executive summary: no critical incidents detected."
502+
)
503+
return await handler(request)
504+
```
505+
506+
Retry pattern (bounded retries):
507+
508+
```py
509+
from splunklib.ai.middleware import (
510+
tool_middleware,
511+
ToolMiddlewareHandler,
512+
ToolRequest,
513+
ToolResponse,
514+
)
515+
516+
517+
class RetryableToolError(Exception): pass
518+
519+
520+
@tool_middleware
521+
async def retry_transient_tool_failures(
522+
request: ToolRequest, handler: ToolMiddlewareHandler
523+
) -> ToolResponse:
524+
last_error: Exception | None = None
525+
for _ in range(3):
526+
try:
527+
return await handler(request)
528+
except RetryableToolError as e:
529+
last_error = e
530+
531+
assert last_error is not None
532+
raise last_error
533+
```
534+
535+
Pass middleware to `Agent`:
536+
537+
```py
538+
async with Agent(
539+
model=model,
540+
service=service,
541+
system_prompt="...",
542+
middleware=[redact_system_prompt, mock_temperature, mock_subagent],
543+
) as agent: ...
544+
```
545+
389546
## Hooks
390547

391548
Hooks are user-defined callback functions that can be registered to execute at specific points

splunklib/ai/agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
16-
from logging import Logger
1715
import os
1816
from collections.abc import AsyncGenerator, Sequence
17+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
18+
from logging import Logger
1919
from typing import Self, final, override
2020

2121
from pydantic import BaseModel
@@ -25,6 +25,7 @@
2525
from splunklib.ai.core.backend_registry import get_backend
2626
from splunklib.ai.hooks import AgentHook
2727
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
28+
from splunklib.ai.middleware import AgentMiddleware
2829
from splunklib.ai.model import PredefinedModel
2930
from splunklib.ai.tool_filtering import ToolFilters, filter_tools
3031
from splunklib.ai.tools import (
@@ -130,6 +131,7 @@ def __init__(
130131
output_schema: type[OutputT] | None = None,
131132
input_schema: type[BaseModel] | None = None, # Only used by Subagents
132133
hooks: Sequence[AgentHook] | None = None,
134+
middleware: Sequence[AgentMiddleware] | None = None,
133135
name: str = "", # Only used by Subagents
134136
description: str = "", # Only used by Subagents
135137
logger: Logger | None = None,
@@ -143,6 +145,7 @@ def __init__(
143145
input_schema=input_schema,
144146
output_schema=output_schema,
145147
hooks=hooks,
148+
middleware=middleware,
146149
logger=logger,
147150
)
148151

splunklib/ai/base_agent.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from abc import ABC, abstractmethod
16-
from collections.abc import Sequence
1715
import logging
1816
import secrets
17+
from abc import ABC, abstractmethod
18+
from collections.abc import Sequence
1919
from typing import Generic
2020

2121
from pydantic import BaseModel
2222

2323
from splunklib.ai.hooks import AgentHook
2424
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
25+
from splunklib.ai.middleware import AgentMiddleware
2526
from splunklib.ai.model import PredefinedModel
2627
from splunklib.ai.tools import Tool
2728

@@ -36,6 +37,7 @@ class BaseAgent(Generic[OutputT], ABC):
3637
_input_schema: type[BaseModel] | None = None
3738
_output_schema: type[OutputT] | None = None
3839
_hooks: Sequence[AgentHook] | None = None
40+
_middleware: Sequence[AgentMiddleware] | None = None
3941
_trace_id: str
4042
_logger: logging.Logger
4143

@@ -50,6 +52,7 @@ def __init__(
5052
input_schema: type[BaseModel] | None = None,
5153
output_schema: type[OutputT] | None = None,
5254
hooks: Sequence[AgentHook] | None = None,
55+
middleware: Sequence[AgentMiddleware] | None = None,
5356
logger: logging.Logger | None = None,
5457
) -> None:
5558
self._system_prompt = system_prompt
@@ -61,6 +64,7 @@ def __init__(
6164
self._input_schema = input_schema
6265
self._output_schema = output_schema
6366
self._hooks = tuple(hooks) if hooks else ()
67+
self._middleware = tuple(middleware) if middleware else ()
6468
self._trace_id = secrets.token_hex(16) # 32 Hex characters
6569

6670
if logger is None:
@@ -112,6 +116,10 @@ def output_schema(self) -> type[OutputT] | None:
112116
def hooks(self) -> Sequence[AgentHook] | None:
113117
return self._hooks
114118

119+
@property
120+
def middleware(self) -> Sequence[AgentMiddleware] | None:
121+
return self._middleware
122+
115123
@property
116124
def trace_id(self) -> str:
117125
return self._trace_id

0 commit comments

Comments
 (0)