3535 AgentRequest ,
3636 ModelMiddlewareHandler ,
3737 ModelRequest ,
38+ ModelResponse ,
3839 SubagentMiddlewareHandler ,
3940 SubagentRequest ,
4041 SubagentResponse ,
@@ -274,7 +275,7 @@ async def tool_test_middleware(
274275 @model_middleware
275276 async def model_test_middleware (
276277 request : ModelRequest , handler : ModelMiddlewareHandler
277- ) -> AIMessage :
278+ ) -> ModelResponse :
278279 nonlocal model_called
279280 model_called = True
280281 return await handler (request )
@@ -310,7 +311,7 @@ class ExampleMiddleware(AgentMiddleware):
310311 @override
311312 async def model_middleware (
312313 self , request : ModelRequest , handler : ModelMiddlewareHandler
313- ) -> AIMessage :
314+ ) -> ModelResponse :
314315 nonlocal model_called
315316 model_called = True
316317 return await handler (request )
@@ -512,21 +513,21 @@ async def test_agent_middleware_model_retry(self) -> None:
512513 @model_middleware
513514 async def test_middleware (
514515 request : ModelRequest , handler : ModelMiddlewareHandler
515- ) -> AIMessage :
516+ ) -> ModelResponse :
516517 nonlocal middleware_called
517518 middleware_called = True
518519
519520 first_result = await handler (request )
520- assert isinstance (first_result , AIMessage )
521+ assert isinstance (first_result , ModelResponse )
521522
522523 second_result = await handler (request )
523524
524525 # Only if it's a model response that contains the tool calls
525- if first_result .calls :
526- tool_call = first_result .calls [0 ]
526+ if first_result .message . calls :
527+ tool_call = first_result .message . calls [0 ]
527528 assert isinstance (tool_call , ToolCall )
528529
529- second_tool_call = first_result .calls [0 ]
530+ second_tool_call = first_result .message . calls [0 ]
530531 assert isinstance (second_tool_call , ToolCall )
531532
532533 assert tool_call .name == second_tool_call .name == "temperature"
@@ -562,21 +563,21 @@ class NicknameGeneratorInput(BaseModel):
562563 @model_middleware
563564 async def test_middleware (
564565 request : ModelRequest , handler : ModelMiddlewareHandler
565- ) -> AIMessage :
566+ ) -> ModelResponse :
566567 nonlocal middleware_called
567568 middleware_called = True
568569
569570 first_result = await handler (request )
570- assert isinstance (first_result , AIMessage )
571+ assert isinstance (first_result , ModelResponse )
571572
572573 second_result = await handler (request )
573574
574575 # only if it's a model response that contains the subagent calls
575- if first_result .calls :
576- subagent_call = first_result .calls [0 ]
576+ if first_result .message . calls :
577+ subagent_call = first_result .message . calls [0 ]
577578 assert isinstance (subagent_call , SubagentCall )
578579
579- second_subagent_call = first_result .calls [0 ]
580+ second_subagent_call = first_result .message . calls [0 ]
580581 assert isinstance (second_subagent_call , SubagentCall )
581582
582583 assert (
@@ -627,11 +628,11 @@ async def test_agent_middleware_model_made_up_response(self) -> None:
627628 @model_middleware
628629 async def test_middleware (
629630 _request : ModelRequest , _handler : ModelMiddlewareHandler
630- ) -> AIMessage :
631+ ) -> ModelResponse :
631632 nonlocal middleware_called
632633 middleware_called = True
633634
634- return AIMessage (content = "My response is made up" )
635+ return ModelResponse ( message = AIMessage (content = "My response is made up" ) )
635636
636637 async with Agent (
637638 model = await self .model (),
@@ -658,7 +659,7 @@ async def test_agent_middleware_model_exception_raised(self) -> None:
658659 @model_middleware
659660 async def test_middleware (
660661 _request : ModelRequest , _handler : ModelMiddlewareHandler
661- ) -> AIMessage :
662+ ) -> ModelResponse :
662663 raise Exception ("testing" )
663664
664665 async with Agent (
@@ -676,6 +677,86 @@ async def test_middleware(
676677 ]
677678 )
678679
680+ @pytest .mark .asyncio
681+ async def test_model_middleware_structured_output (self ) -> None :
682+ pytest .importorskip ("langchain_openai" )
683+
684+ # Regression test - make sure that model middleware does not
685+ # cause structured output to be dropped.
686+
687+ class Output (BaseModel ):
688+ name : str = Field (description = "name of the Person" )
689+
690+ @model_middleware
691+ async def test_middleware (
692+ req : ModelRequest , handler : ModelMiddlewareHandler
693+ ) -> ModelResponse :
694+ return await handler (req )
695+
696+ async with Agent (
697+ model = await self .model (),
698+ system_prompt = "Your name is stefan" ,
699+ service = self .service ,
700+ middleware = [test_middleware ],
701+ output_schema = Output ,
702+ ) as agent :
703+ resp = await agent .invoke ([HumanMessage (content = "What is your name?" )])
704+ assert resp .structured_output .name .lower () == "stefan"
705+
706+ @pytest .mark .asyncio
707+ async def test_model_middleware_modify_structured_output (self ) -> None :
708+ pytest .importorskip ("langchain_openai" )
709+
710+ class Output (BaseModel ):
711+ name : str = Field (description = "name of the Person" )
712+
713+ @model_middleware
714+ async def test_middleware (
715+ req : ModelRequest , handler : ModelMiddlewareHandler
716+ ) -> ModelResponse :
717+ resp = await handler (req )
718+ assert type (resp .structured_output ) is Output
719+ resp .structured_output .name = "Mike"
720+ return resp
721+
722+ async with Agent (
723+ model = await self .model (),
724+ system_prompt = "Your name is stefan" ,
725+ service = self .service ,
726+ middleware = [test_middleware ],
727+ output_schema = Output ,
728+ ) as agent :
729+ resp = await agent .invoke ([HumanMessage (content = "What is your name?" )])
730+ assert resp .structured_output .name == "Mike"
731+
732+ @pytest .mark .asyncio
733+ async def test_model_middleware_made_up_structured_output (self ) -> None :
734+ pytest .importorskip ("langchain_openai" )
735+
736+ class Output (BaseModel ):
737+ name : str = Field (description = "name of the Person" )
738+
739+ @model_middleware
740+ async def test_middleware (
741+ _req : ModelRequest , _handler : ModelMiddlewareHandler
742+ ) -> ModelResponse :
743+ return ModelResponse (
744+ message = AIMessage (
745+ content = "Stefan" ,
746+ ),
747+ structured_output = Output (name = "Stefan" ),
748+ )
749+
750+ async with Agent (
751+ model = await self .model (),
752+ system_prompt = "Your name is stefan" ,
753+ service = self .service ,
754+ middleware = [test_middleware ],
755+ output_schema = Output ,
756+ ) as agent :
757+ resp = await agent .invoke ([HumanMessage (content = "What is your name?" )])
758+ assert resp .structured_output .name .lower () == "stefan"
759+
679760 @pytest .mark .asyncio
680761 async def test_agent_middleware (self ) -> None :
681762 pytest .importorskip ("langchain_openai" )
0 commit comments