Skip to content

Commit 6a6178b

Browse files
authored
Fix convertion type errors in langchain backend (#76)
1 parent 75899b0 commit 6a6178b

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from langchain.agents.middleware.summarization import TokenCounter
4242
from langchain.agents.middleware.types import ModelCallResult
4343
from langchain.messages import AIMessage as LC_AIMessage
44+
from langchain.messages import AnyMessage as LC_AnyMessage
4445
from langchain.messages import HumanMessage as LC_HumanMessage
4546
from langchain.messages import SystemMessage as LC_SystemMessage
4647
from langchain.messages import ToolCall as LC_ToolCall
@@ -529,16 +530,21 @@ def _convert_tool_message_from_lc(
529530

530531
def _convert_model_result_from_lc(model_response: ModelCallResult) -> AIMessage:
531532
if isinstance(model_response, ModelResponse):
532-
model_response = model_response.result[-1]
533+
ai_message = next(
534+
(m for m in model_response.result if isinstance(m, LC_AIMessage)), None
535+
)
536+
assert ai_message, "ModelResponse should contain at least one LC_AIMessage"
537+
else:
538+
ai_message = model_response
533539

534540
return AIMessage(
535-
content=model_response.content,
536-
calls=[_map_tool_call_from_langchain(tc) for tc in model_response.tool_calls],
541+
content=ai_message.content.__str__(),
542+
calls=[_map_tool_call_from_langchain(tc) for tc in ai_message.tool_calls],
537543
)
538544

539545

540546
def _convert_agent_state_to_lc(state: AgentState) -> LC_AgentState: # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
541-
return LC_AgentState(
547+
return LC_AgentState( # pyright: ignore[reportUnknownVariableType]
542548
messages=[_map_message_to_langchain(m) for m in state.response.messages],
543549
)
544550

@@ -751,7 +757,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
751757
raise InvalidMessageTypeError("Invalid langchain message type")
752758

753759

754-
def _map_message_to_langchain(message: BaseMessage) -> LC_BaseMessage:
760+
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
755761
match message:
756762
case AIMessage():
757763
lc_message = LC_AIMessage(content=message.content)

0 commit comments

Comments
 (0)