|
41 | 41 | from langchain.agents.middleware.summarization import TokenCounter |
42 | 42 | from langchain.agents.middleware.types import ModelCallResult |
43 | 43 | from langchain.messages import AIMessage as LC_AIMessage |
| 44 | +from langchain.messages import AnyMessage as LC_AnyMessage |
44 | 45 | from langchain.messages import HumanMessage as LC_HumanMessage |
45 | 46 | from langchain.messages import SystemMessage as LC_SystemMessage |
46 | 47 | from langchain.messages import ToolCall as LC_ToolCall |
@@ -529,16 +530,21 @@ def _convert_tool_message_from_lc( |
529 | 530 |
|
530 | 531 | def _convert_model_result_from_lc(model_response: ModelCallResult) -> AIMessage: |
531 | 532 | if isinstance(model_response, ModelResponse): |
532 | | - model_response = model_response.result[-1] |
| 533 | + ai_message = next( |
| 534 | + (m for m in model_response.result if isinstance(m, LC_AIMessage)), None |
| 535 | + ) |
| 536 | + assert ai_message, "ModelResponse should contain at least one LC_AIMessage" |
| 537 | + else: |
| 538 | + ai_message = model_response |
533 | 539 |
|
534 | 540 | return AIMessage( |
535 | | - content=model_response.content, |
536 | | - calls=[_map_tool_call_from_langchain(tc) for tc in model_response.tool_calls], |
| 541 | + content=ai_message.content.__str__(), |
| 542 | + calls=[_map_tool_call_from_langchain(tc) for tc in ai_message.tool_calls], |
537 | 543 | ) |
538 | 544 |
|
539 | 545 |
|
540 | 546 | def _convert_agent_state_to_lc(state: AgentState) -> LC_AgentState: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType] |
541 | | - return LC_AgentState( |
| 547 | + return LC_AgentState( # pyright: ignore[reportUnknownVariableType] |
542 | 548 | messages=[_map_message_to_langchain(m) for m in state.response.messages], |
543 | 549 | ) |
544 | 550 |
|
@@ -751,7 +757,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: |
751 | 757 | raise InvalidMessageTypeError("Invalid langchain message type") |
752 | 758 |
|
753 | 759 |
|
754 | | -def _map_message_to_langchain(message: BaseMessage) -> LC_BaseMessage: |
| 760 | +def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage: |
755 | 761 | match message: |
756 | 762 | case AIMessage(): |
757 | 763 | lc_message = LC_AIMessage(content=message.content) |
|
0 commit comments