@@ -57,20 +57,35 @@ async def test_agent_hooks_duplicated(self):
5757 async def test_agent_hook (self ):
5858 pytest .importorskip ("langchain_openai" )
5959
60+ hook_calls = 0
61+
6062 @final
6163 class TestHook (AgentHook ):
6264 type = "before_model"
63- name = "test_hook "
65+ name = "test_async_hook "
6466
6567 @override
6668 def __call__ (self , state : AgentState ) -> None :
69+ nonlocal hook_calls
70+ hook_calls += 1
71+ assert len (state .response .messages ) == 1
72+
73+ @final
74+ class TestAsyncHook (AgentHook ):
75+ type = "before_model"
76+ name = "test_hook"
77+
78+ @override
79+ async def __call__ (self , state : AgentState ) -> None :
80+ nonlocal hook_calls
81+ hook_calls += 1
6782 assert len (state .response .messages ) == 1
6883
6984 async with Agent (
7085 model = (await self .model ()),
7186 system_prompt = "Your name is stefan" ,
7287 service = self .service ,
73- hooks = [TestHook ()],
88+ hooks = [TestHook (), TestAsyncHook () ],
7489 ) as agent :
7590 result = await agent .invoke (
7691 [
@@ -82,6 +97,7 @@ def __call__(self, state: AgentState) -> None:
8297
8398 response = result .messages [- 1 ].content .strip ().lower ().replace ("." , "" )
8499 assert "stefan" == response
100+ assert hook_calls == 2
85101
86102 @pytest .mark .asyncio
87103 async def test_agent_hook_decorator (self ):
@@ -96,18 +112,37 @@ def test_hook_before(state: AgentState) -> None:
96112
97113 assert len (state .response .messages ) == 1
98114
115+ @before_model
116+ async def test_async_hook_before (state : AgentState ) -> None :
117+ nonlocal hook_calls
118+ hook_calls += 1
119+
120+ assert len (state .response .messages ) == 1
121+
99122 @after_model
100123 def test_hook_after (state : AgentState ) -> None :
101124 nonlocal hook_calls
102125 hook_calls += 1
103126
104127 assert len (state .response .messages ) == 2
105128
129+ @after_model
130+ async def test_async_hook_after (state : AgentState ) -> None :
131+ nonlocal hook_calls
132+ hook_calls += 1
133+
134+ assert len (state .response .messages ) == 2
135+
106136 async with Agent (
107137 model = (await self .model ()),
108138 system_prompt = "Your name is stefan" ,
109139 service = self .service ,
110- hooks = [test_hook_before , test_hook_after ],
140+ hooks = [
141+ test_hook_before ,
142+ test_async_hook_before ,
143+ test_hook_after ,
144+ test_async_hook_after ,
145+ ],
111146 ) as agent :
112147 result = await agent .invoke (
113148 [
@@ -119,7 +154,7 @@ def test_hook_after(state: AgentState) -> None:
119154
120155 response = result .messages [- 1 ].content .strip ().lower ().replace ("." , "" )
121156 assert "stefan" == response
122- assert hook_calls == 2
157+ assert hook_calls == 4
123158
124159 @pytest .mark .asyncio
125160 async def test_agent_hook_agent (self ):
@@ -137,8 +172,24 @@ def before_agent_hook(state: AgentState) -> None:
137172
138173 assert len (state .response .messages ) == 1
139174
175+ @before_agent
176+ async def before_async_agent_hook (state : AgentState ) -> None :
177+ nonlocal hook_calls
178+ hook_calls += 1
179+
180+ assert len (state .response .messages ) == 1
181+
140182 @after_agent
141- def after_agent_hook (state : AgentState ) -> None :
183+ async def after_agent_hook (state : AgentState ) -> None :
184+ nonlocal hook_calls
185+ hook_calls += 1
186+
187+ person = state .response .structured_output
188+ assert person .name .lower () == "stefan"
189+ assert len (state .response .messages ) == 2
190+
191+ @after_agent
192+ async def after_async_agent_hook (state : AgentState ) -> None :
142193 nonlocal hook_calls
143194 hook_calls += 1
144195
@@ -150,7 +201,12 @@ def after_agent_hook(state: AgentState) -> None:
150201 model = (await self .model ()),
151202 system_prompt = "Your name is stefan" ,
152203 service = self .service ,
153- hooks = [before_agent_hook , after_agent_hook ],
204+ hooks = [
205+ before_agent_hook ,
206+ before_async_agent_hook ,
207+ after_agent_hook ,
208+ after_async_agent_hook ,
209+ ],
154210 output_schema = Person ,
155211 ) as agent :
156212 result = await agent .invoke (
@@ -163,7 +219,7 @@ def after_agent_hook(state: AgentState) -> None:
163219
164220 response = result .messages [- 1 ].content .strip ().lower ().replace ("." , "" )
165221 assert '{"name":"stefan"}' == response
166- assert hook_calls == 2
222+ assert hook_calls == 4
167223
168224 @pytest .mark .asyncio
169225 async def test_agent_loop_stop_conditions_token_limit (self ):
0 commit comments