Skip to content

Commit 04dfeb6

Browse files
authored
Propagate structured outputs in model middleware (#81)
1 parent f63d44f commit 04dfeb6

4 files changed

Lines changed: 134 additions & 34 deletions

File tree

splunklib/ai/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -407,14 +407,15 @@ from splunklib.ai.middleware import (
407407
AgentRequest,
408408
ModelMiddlewareHandler,
409409
ModelRequest,
410+
ModelResponse,
410411
SubagentMiddlewareHandler,
411412
SubagentRequest,
412413
SubagentResponse,
413414
ToolMiddlewareHandler,
414415
ToolRequest,
415416
ToolResponse,
416417
)
417-
from splunklib.ai.messages import AIMessage, AgentResponse, ToolCall
418+
from splunklib.ai.messages import AgentResponse, ToolCall
418419

419420

420421
class ExampleMiddleware(AgentMiddleware):
@@ -431,7 +432,7 @@ class ExampleMiddleware(AgentMiddleware):
431432
@override
432433
async def model_middleware(
433434
self, request: ModelRequest, handler: ModelMiddlewareHandler
434-
) -> AIMessage:
435+
) -> ModelResponse:
435436
request.system_message = request.system_message.replace("SECRET", "[REDACTED]")
436437
return await handler(request)
437438

@@ -484,14 +485,13 @@ from splunklib.ai.middleware import (
484485
model_middleware,
485486
ModelMiddlewareHandler,
486487
ModelRequest,
488+
ModelResponse,
487489
)
488-
from splunklib.ai.messages import AIMessage
489-
490490

491491
@model_middleware
492492
async def redact_system_prompt(
493493
request: ModelRequest, handler: ModelMiddlewareHandler
494-
) -> AIMessage:
494+
) -> ModelResponse:
495495
request.system_message = request.system_message.replace("SECRET", "[REDACTED]")
496496
return await handler(request)
497497
```

splunklib/ai/engines/langchain.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
AgentRequest,
8686
ModelMiddlewareHandler,
8787
ModelRequest,
88+
ModelResponse,
8889
SubagentMiddlewareHandler,
8990
SubagentRequest,
9091
SubagentResponse,
@@ -352,7 +353,7 @@ async def awrap_model_call(
352353
sdk_request,
353354
_convert_model_handler_from_lc(handler, original_request=request),
354355
)
355-
return _convert_ai_message_to_model_result(sdk_response)
356+
return _convert_model_response_to_model_result(sdk_response)
356357

357358
@override
358359
async def awrap_tool_call(
@@ -436,7 +437,7 @@ def _convert_model_handler_from_lc(
436437
handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
437438
original_request: LC_ModelRequest,
438439
) -> ModelMiddlewareHandler:
439-
async def _sdk_handler(request: ModelRequest) -> AIMessage:
440+
async def _sdk_handler(request: ModelRequest) -> ModelResponse:
440441
lc_request = _convert_model_request_to_lc(request, original_request)
441442
result = await handler(lc_request)
442443

@@ -508,10 +509,17 @@ def _convert_model_request_to_lc(
508509
)
509510

510511

511-
def _convert_ai_message_to_model_result(message: AIMessage) -> LC_ModelCallResult:
512-
lc_message = LC_AIMessage(content=message.content)
512+
def _convert_model_response_to_model_result(
513+
resp: ModelResponse,
514+
) -> LC_ModelCallResult:
515+
lc_message = LC_AIMessage(content=resp.message.content)
513516
# This field can't be set via __init__()
514-
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in message.calls]
517+
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]
518+
if resp.structured_output is not None:
519+
return LC_ModelResponse(
520+
result=[lc_message],
521+
structured_response=resp.structured_output,
522+
)
515523
return lc_message
516524

517525

@@ -585,18 +593,23 @@ def _convert_tool_message_from_lc(
585593
raise NotImplementedError("Command is not supported")
586594

587595

588-
def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> AIMessage:
596+
def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelResponse:
589597
if isinstance(model_response, LC_ModelResponse):
590598
ai_message = next(
591599
(m for m in model_response.result if isinstance(m, LC_AIMessage)), None
592600
)
593601
assert ai_message, "ModelResponse should contain at least one LC_AIMessage"
602+
structured_response = model_response.structured_response
594603
else:
595604
ai_message = model_response
596-
597-
return AIMessage(
598-
content=ai_message.content.__str__(),
599-
calls=[_map_tool_call_from_langchain(tc) for tc in ai_message.tool_calls],
605+
structured_response = None
606+
607+
return ModelResponse(
608+
message=AIMessage(
609+
content=ai_message.content.__str__(),
610+
calls=[_map_tool_call_from_langchain(tc) for tc in ai_message.tool_calls],
611+
),
612+
structured_output=structured_response,
600613
)
601614

602615

splunklib/ai/middleware.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ class ModelRequest:
7373
state: AgentState
7474

7575

76-
ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[AIMessage]]
76+
@dataclass
77+
class ModelResponse:
78+
message: AIMessage
79+
structured_output: Any | None = None
80+
81+
82+
ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]
7783

7884

7985
@dataclass
@@ -107,7 +113,7 @@ async def model_middleware(
107113
self,
108114
request: ModelRequest,
109115
handler: ModelMiddlewareHandler,
110-
) -> AIMessage:
116+
) -> ModelResponse:
111117
"""Executed in between the LLM calls"""
112118

113119
return await handler(request)
@@ -155,15 +161,15 @@ async def subagent_middleware(
155161

156162

157163
def model_middleware(
158-
func: Callable[[ModelRequest, ModelMiddlewareHandler], Awaitable[AIMessage]],
164+
func: Callable[[ModelRequest, ModelMiddlewareHandler], Awaitable[ModelResponse]],
159165
) -> AgentMiddleware:
160166
class _CustomMiddleware(AgentMiddleware):
161167
@override
162168
async def model_middleware(
163169
self,
164170
request: ModelRequest,
165171
handler: ModelMiddlewareHandler,
166-
) -> AIMessage:
172+
) -> ModelResponse:
167173
return await func(request, handler)
168174

169175
return _CustomMiddleware()

tests/integration/ai/test_middleware.py

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
AgentRequest,
3636
ModelMiddlewareHandler,
3737
ModelRequest,
38+
ModelResponse,
3839
SubagentMiddlewareHandler,
3940
SubagentRequest,
4041
SubagentResponse,
@@ -274,7 +275,7 @@ async def tool_test_middleware(
274275
@model_middleware
275276
async def model_test_middleware(
276277
request: ModelRequest, handler: ModelMiddlewareHandler
277-
) -> AIMessage:
278+
) -> ModelResponse:
278279
nonlocal model_called
279280
model_called = True
280281
return await handler(request)
@@ -310,7 +311,7 @@ class ExampleMiddleware(AgentMiddleware):
310311
@override
311312
async def model_middleware(
312313
self, request: ModelRequest, handler: ModelMiddlewareHandler
313-
) -> AIMessage:
314+
) -> ModelResponse:
314315
nonlocal model_called
315316
model_called = True
316317
return await handler(request)
@@ -512,21 +513,21 @@ async def test_agent_middleware_model_retry(self) -> None:
512513
@model_middleware
513514
async def test_middleware(
514515
request: ModelRequest, handler: ModelMiddlewareHandler
515-
) -> AIMessage:
516+
) -> ModelResponse:
516517
nonlocal middleware_called
517518
middleware_called = True
518519

519520
first_result = await handler(request)
520-
assert isinstance(first_result, AIMessage)
521+
assert isinstance(first_result, ModelResponse)
521522

522523
second_result = await handler(request)
523524

524525
# Only if it's a model response that contains the tool calls
525-
if first_result.calls:
526-
tool_call = first_result.calls[0]
526+
if first_result.message.calls:
527+
tool_call = first_result.message.calls[0]
527528
assert isinstance(tool_call, ToolCall)
528529

529-
second_tool_call = first_result.calls[0]
530+
second_tool_call = first_result.message.calls[0]
530531
assert isinstance(second_tool_call, ToolCall)
531532

532533
assert tool_call.name == second_tool_call.name == "temperature"
@@ -562,21 +563,21 @@ class NicknameGeneratorInput(BaseModel):
562563
@model_middleware
563564
async def test_middleware(
564565
request: ModelRequest, handler: ModelMiddlewareHandler
565-
) -> AIMessage:
566+
) -> ModelResponse:
566567
nonlocal middleware_called
567568
middleware_called = True
568569

569570
first_result = await handler(request)
570-
assert isinstance(first_result, AIMessage)
571+
assert isinstance(first_result, ModelResponse)
571572

572573
second_result = await handler(request)
573574

574575
# only if it's a model response that contains the subagent calls
575-
if first_result.calls:
576-
subagent_call = first_result.calls[0]
576+
if first_result.message.calls:
577+
subagent_call = first_result.message.calls[0]
577578
assert isinstance(subagent_call, SubagentCall)
578579

579-
second_subagent_call = first_result.calls[0]
580+
second_subagent_call = first_result.message.calls[0]
580581
assert isinstance(second_subagent_call, SubagentCall)
581582

582583
assert (
@@ -627,11 +628,11 @@ async def test_agent_middleware_model_made_up_response(self) -> None:
627628
@model_middleware
628629
async def test_middleware(
629630
_request: ModelRequest, _handler: ModelMiddlewareHandler
630-
) -> AIMessage:
631+
) -> ModelResponse:
631632
nonlocal middleware_called
632633
middleware_called = True
633634

634-
return AIMessage(content="My response is made up")
635+
return ModelResponse(message=AIMessage(content="My response is made up"))
635636

636637
async with Agent(
637638
model=await self.model(),
@@ -658,7 +659,7 @@ async def test_agent_middleware_model_exception_raised(self) -> None:
658659
@model_middleware
659660
async def test_middleware(
660661
_request: ModelRequest, _handler: ModelMiddlewareHandler
661-
) -> AIMessage:
662+
) -> ModelResponse:
662663
raise Exception("testing")
663664

664665
async with Agent(
@@ -676,6 +677,86 @@ async def test_middleware(
676677
]
677678
)
678679

680+
@pytest.mark.asyncio
681+
async def test_model_middleware_structured_output(self) -> None:
682+
pytest.importorskip("langchain_openai")
683+
684+
# Regression test - make sure that model middleware does not
685+
# cause structured output to be dropped.
686+
687+
class Output(BaseModel):
688+
name: str = Field(description="name of the Person")
689+
690+
@model_middleware
691+
async def test_middleware(
692+
req: ModelRequest, handler: ModelMiddlewareHandler
693+
) -> ModelResponse:
694+
return await handler(req)
695+
696+
async with Agent(
697+
model=await self.model(),
698+
system_prompt="Your name is stefan",
699+
service=self.service,
700+
middleware=[test_middleware],
701+
output_schema=Output,
702+
) as agent:
703+
resp = await agent.invoke([HumanMessage(content="What is your name?")])
704+
assert resp.structured_output.name.lower() == "stefan"
705+
706+
@pytest.mark.asyncio
707+
async def test_model_middleware_modify_structured_output(self) -> None:
708+
pytest.importorskip("langchain_openai")
709+
710+
class Output(BaseModel):
711+
name: str = Field(description="name of the Person")
712+
713+
@model_middleware
714+
async def test_middleware(
715+
req: ModelRequest, handler: ModelMiddlewareHandler
716+
) -> ModelResponse:
717+
resp = await handler(req)
718+
assert type(resp.structured_output) is Output
719+
resp.structured_output.name = "Mike"
720+
return resp
721+
722+
async with Agent(
723+
model=await self.model(),
724+
system_prompt="Your name is stefan",
725+
service=self.service,
726+
middleware=[test_middleware],
727+
output_schema=Output,
728+
) as agent:
729+
resp = await agent.invoke([HumanMessage(content="What is your name?")])
730+
assert resp.structured_output.name == "Mike"
731+
732+
@pytest.mark.asyncio
733+
async def test_model_middleware_made_up_structured_output(self) -> None:
734+
pytest.importorskip("langchain_openai")
735+
736+
class Output(BaseModel):
737+
name: str = Field(description="name of the Person")
738+
739+
@model_middleware
740+
async def test_middleware(
741+
_req: ModelRequest, _handler: ModelMiddlewareHandler
742+
) -> ModelResponse:
743+
return ModelResponse(
744+
message=AIMessage(
745+
content="Stefan",
746+
),
747+
structured_output=Output(name="Stefan"),
748+
)
749+
750+
async with Agent(
751+
model=await self.model(),
752+
system_prompt="Your name is stefan",
753+
service=self.service,
754+
middleware=[test_middleware],
755+
output_schema=Output,
756+
) as agent:
757+
resp = await agent.invoke([HumanMessage(content="What is your name?")])
758+
assert resp.structured_output.name.lower() == "stefan"
759+
679760
@pytest.mark.asyncio
680761
async def test_agent_middleware(self) -> None:
681762
pytest.importorskip("langchain_openai")

0 commit comments

Comments
 (0)