|
79 | 79 | ToolMessage, |
80 | 80 | ) |
81 | 81 | from splunklib.ai.middleware import ( |
| 82 | + AgentMiddlewareHandler, |
82 | 83 | AgentState, |
83 | 84 | AgentMiddleware, |
| 85 | + AgentRequest, |
84 | 86 | ModelMiddlewareHandler, |
85 | 87 | ModelRequest, |
86 | 88 | SubagentMiddlewareHandler, |
@@ -122,55 +124,128 @@ class LangChainAgentImpl(AgentImpl[OutputT]): |
122 | 124 | _thread_id: uuid.UUID |
123 | 125 | _config: RunnableConfig |
124 | 126 | _output_schema: type[OutputT] | None |
| 127 | + _middleware: Sequence[AgentMiddleware] |
125 | 128 |
|
126 | 129 | def __init__( |
127 | 130 | self, |
128 | 131 | system_prompt: str, |
129 | 132 | model: BaseChatModel, |
130 | 133 | tools: list[BaseTool], |
131 | 134 | output_schema: type[OutputT] | None, |
132 | | - middleware: Sequence[LC_AgentMiddleware] | None = None, |
| 135 | + lcmiddleware: Sequence[LC_AgentMiddleware] | None = None, |
| 136 | + middleware: Sequence[AgentMiddleware] | None = None, |
133 | 137 | ) -> None: |
134 | 138 | super().__init__() |
135 | 139 | self._output_schema = output_schema |
136 | 140 | self._thread_id = uuid.uuid4() |
137 | 141 | self._config = {"configurable": {"thread_id": self._thread_id}} |
| 142 | + self._middleware = middleware or [] |
138 | 143 |
|
139 | 144 | checkpointer = InMemorySaver() |
140 | | - middleware = middleware or [] |
141 | 145 |
|
142 | 146 | self._agent = create_agent( |
143 | 147 | model=model, |
144 | 148 | tools=tools, |
145 | 149 | system_prompt=system_prompt, |
146 | 150 | checkpointer=checkpointer, |
147 | 151 | response_format=output_schema, |
148 | | - middleware=middleware, |
| 152 | + middleware=lcmiddleware or [], |
149 | 153 | ) |
150 | 154 |
|
| 155 | + def _with_agent_middleware( |
| 156 | + self, |
| 157 | + agent_invoke: Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]], |
| 158 | + ) -> Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]]: |
| 159 | + # When provided with a list of middlewares, e.g. [m1, m2, m3], |
| 160 | + # they are executed in the following order: |
| 161 | + # |
| 162 | + # m1 -> m2 -> m3 -> agent_invoke |
| 163 | + # |
| 164 | + # Each middleware wraps the next one in the chain. |
| 165 | + # |
| 166 | + # - m1's handler calls m2.agent_middleware(...) |
| 167 | + # - m2's handler calls m3.agent_middleware(...) |
| 168 | + # - m3's handler eventually calls agent_invoke(...) |
| 169 | + # |
| 170 | + # We build the chain by iterating in reverse order. |
| 171 | + # Each middleware wraps the previously constructed handler, |
| 172 | + # so the first middleware in the list becomes the outermost one. |
| 173 | + |
| 174 | + invoke = agent_invoke |
| 175 | + for middleware in reversed(self._middleware): |
| 176 | + |
| 177 | + def make_next( |
| 178 | + m: AgentMiddleware, h: AgentMiddlewareHandler |
| 179 | + ) -> AgentMiddlewareHandler: |
| 180 | + async def next(r: AgentRequest) -> AgentResponse[Any | None]: |
| 181 | + return await m.agent_middleware(r, h) |
| 182 | + |
| 183 | + return next |
| 184 | + |
| 185 | + invoke = make_next(middleware, invoke) |
| 186 | + |
| 187 | + return invoke |
| 188 | + |
151 | 189 | @override |
152 | 190 | async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]: |
153 | | - langchain_msgs = [_map_message_to_langchain(m) for m in messages] |
| 191 | + async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]: |
| 192 | + langchain_msgs = [_map_message_to_langchain(m) for m in req.messages] |
154 | 193 |
|
155 | | - # call the langchain agent |
156 | | - result = await self._agent.ainvoke( |
157 | | - {"messages": langchain_msgs}, |
158 | | - config=self._config, |
159 | | - ) |
| 194 | + # call the langchain agent |
| 195 | + result = await self._agent.ainvoke( |
| 196 | + {"messages": langchain_msgs}, |
| 197 | + config=self._config, |
| 198 | + ) |
| 199 | + |
| 200 | + sdk_msgs = [_map_message_from_langchain(m) for m in result["messages"]] |
| 201 | + |
| 202 | + # NOTE: Agent responses will always conform to output schema. Verifying |
| 203 | + # if an LLM made any mistakes or not is _always_ up to the developer. |
| 204 | + |
| 205 | + assert ( |
| 206 | + self._output_schema is None |
| 207 | + or type(result["structured_response"]) is self._output_schema |
| 208 | + ) |
| 209 | + |
| 210 | + if self._output_schema: |
| 211 | + return AgentResponse( |
| 212 | + structured_output=result["structured_response"], |
| 213 | + messages=sdk_msgs, |
| 214 | + ) |
| 215 | + else: |
| 216 | + return AgentResponse(structured_output=None, messages=sdk_msgs) |
160 | 217 |
|
161 | | - sdk_msgs = [_map_message_from_langchain(m) for m in result["messages"]] |
| 218 | + result = await self._with_agent_middleware(invoke_agent)( |
| 219 | + AgentRequest( |
| 220 | + messages=messages, |
| 221 | + ) |
| 222 | + ) |
162 | 223 |
|
163 | | - # NOTE: Agent responses will always conform to output schema. Verifying |
164 | | - # if an LLM made any mistakes or not is _always_ up to the developer. |
165 | 224 | if self._output_schema: |
166 | | - return AgentResponse( |
167 | | - structured_output=result["structured_response"], |
168 | | - messages=sdk_msgs, |
| 225 | + if result.structured_output is None: |
| 226 | + raise AssertionError("Agent middleware discarded a structured output") |
| 227 | + |
| 228 | + if type(result.structured_output) is not self._output_schema: |
| 229 | + raise AssertionError( |
| 230 | + f"Agent middleware returned an invalid structured_output type: {type(result.structured_output)}, want: {self._output_schema}" |
| 231 | + ) |
| 232 | + |
| 233 | + return AgentResponse[OutputT]( |
| 234 | + messages=result.messages, |
| 235 | + structured_output=result.structured_output, |
169 | 236 | ) |
| 237 | + else: |
| 238 | + if result.structured_output is not None: |
| 239 | + raise AssertionError( |
| 240 | + "Agent middleware unexpectedly included a structured output" |
| 241 | + ) |
170 | 242 |
|
171 | | - # HACK: This let's us put None in the structured_output field. It also shows |
172 | | - # None as the field type if no `output_schema`was provided to the Agent class. |
173 | | - return AgentResponse(structured_output=cast(OutputT, None), messages=sdk_msgs) |
| 243 | + return AgentResponse[OutputT]( |
| 244 | + messages=result.messages, |
| 245 | + # HACK: This let's us put None in the structured_output field. It also shows |
| 246 | + # None as the field type if no `output_schema`was provided to the Agent class. |
| 247 | + structured_output=cast(OutputT, None), |
| 248 | + ) |
174 | 249 |
|
175 | 250 |
|
176 | 251 | @final |
@@ -229,7 +304,8 @@ async def create_agent( |
229 | 304 | model=model_impl, |
230 | 305 | tools=tools, |
231 | 306 | output_schema=agent.output_schema, |
232 | | - middleware=middleware, |
| 307 | + lcmiddleware=middleware, |
| 308 | + middleware=agent.middleware, |
233 | 309 | ) |
234 | 310 |
|
235 | 311 |
|
|
0 commit comments