1212# License for the specific language governing permissions and limitations
1313# under the License.
1414
15+ # pyright: reportPrivateUsage=false
16+
1517import unittest
1618
1719import pytest
20+ from langchain .messages import AIMessage as LC_AIMessage
21+ from langchain .messages import HumanMessage as LC_HumanMessage
22+ from langchain .messages import SystemMessage as LC_SystemMessage
23+ from langchain .messages import ToolCall as LC_ToolCall
24+ from langchain .messages import ToolMessage as LC_ToolMessage
1825
19- from langchain .messages import (
20- AIMessage as LC_AIMessage ,
21- HumanMessage as LC_HumanMessage ,
22- SystemMessage as LC_SystemMessage ,
23- ToolCall as LC_ToolCall ,
24- ToolMessage as LC_ToolMessage ,
25- )
26-
27- from splunklib .ai .core .backend import (
28- InvalidMessageTypeError ,
29- InvalidModelError ,
30- )
26+ from splunklib .ai .core .backend import InvalidMessageTypeError , InvalidModelError
3127from splunklib .ai .engines import langchain as lc
3228from splunklib .ai .messages import (
33- AIMessage ,
3429 AgentCall ,
30+ AIMessage ,
3531 HumanMessage ,
3632 SubagentMessage ,
3733 SystemMessage ,
@@ -57,9 +53,9 @@ def test_map_message_from_langchain_ai_with_agent_call(self) -> None:
5753 name = f"{ lc .AGENT_PREFIX } assistant" , args = {"q" : "test" }, id = "tc-2"
5854 )
5955 message = LC_AIMessage (content = "done" , tool_calls = [tool_call ])
60-
6156 mapped = lc ._map_message_from_langchain (message )
6257
58+ assert isinstance (mapped , AIMessage )
6359 assert mapped .calls == [
6460 AgentCall (
6561 name = "assistant" ,
@@ -77,6 +73,7 @@ def test_map_message_from_langchain_ai_with_mixed_calls(self) -> None:
7773
7874 mapped = lc ._map_message_from_langchain (message )
7975
76+ assert isinstance (mapped , AIMessage )
8077 assert mapped .calls == [
8178 ToolCall (name = "lookup" , args = {"q" : "test" }, id = "tc-1" ),
8279 AgentCall (
@@ -129,7 +126,7 @@ def test_map_message_from_langchain_subagent(self) -> None:
129126
130127 def test_map_message_from_langchain_invalid_raises (self ) -> None :
131128 with pytest .raises (InvalidMessageTypeError ):
132- lc ._map_message_from_langchain (object ())
129+ lc ._map_message_from_langchain (object ()) # pyright: ignore[reportArgumentType]
133130
134131
135132class MapMessageToLangchainTests (unittest .TestCase ):
@@ -280,7 +277,7 @@ def test_map_message_to_langchain_subagent(self) -> None:
280277
281278 def test_map_message_to_langchain_invalid_raises (self ) -> None :
282279 with pytest .raises (InvalidMessageTypeError ):
283- lc ._map_message_to_langchain (object ())
280+ lc ._map_message_to_langchain (object ()) # pyright: ignore[reportArgumentType]
284281
285282
286283class CreateLangchainModelTests (unittest .TestCase ):
@@ -289,7 +286,9 @@ def test_create_langchain_model_invalid_raises(self) -> None:
289286 lc ._create_langchain_model (PredefinedModel (model = "unknown" ))
290287
291288 def test_create_langchain_model_openai (self ) -> None :
292- langchain_openai = pytest .importorskip ("langchain_openai" )
289+ pytest .importorskip ("langchain_openai" )
290+ import langchain_openai
291+
293292 model = OpenAIModel (
294293 model = "gpt-test" ,
295294 base_url = "https://example.com" ,
@@ -302,7 +301,3 @@ def test_create_langchain_model_openai(self) -> None:
302301 assert result .model_name == model .model
303302 assert result .openai_api_base == model .base_url
304303 assert result .temperature == model .temperature
305-
306-
307- if __name__ == "__main__" :
308- unittest .main ()
0 commit comments