Skip to content

Commit 8171f30

Browse files
authored
Fix errors and warnings in test_langchain_backend.py (#70)
* Fix errors and warnings in test_langchain_backend.py * Shorten import * Replace brute-force cast()s with asserts
1 parent b6ecd4d commit 8171f30

1 file changed

Lines changed: 16 additions & 21 deletions

File tree

tests/unit/ai/engine/test_langchain_backend.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,22 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15+
# pyright: reportPrivateUsage=false
16+
1517
import unittest
1618

1719
import pytest
20+
from langchain.messages import AIMessage as LC_AIMessage
21+
from langchain.messages import HumanMessage as LC_HumanMessage
22+
from langchain.messages import SystemMessage as LC_SystemMessage
23+
from langchain.messages import ToolCall as LC_ToolCall
24+
from langchain.messages import ToolMessage as LC_ToolMessage
1825

19-
from langchain.messages import (
20-
AIMessage as LC_AIMessage,
21-
HumanMessage as LC_HumanMessage,
22-
SystemMessage as LC_SystemMessage,
23-
ToolCall as LC_ToolCall,
24-
ToolMessage as LC_ToolMessage,
25-
)
26-
27-
from splunklib.ai.core.backend import (
28-
InvalidMessageTypeError,
29-
InvalidModelError,
30-
)
26+
from splunklib.ai.core.backend import InvalidMessageTypeError, InvalidModelError
3127
from splunklib.ai.engines import langchain as lc
3228
from splunklib.ai.messages import (
33-
AIMessage,
3429
AgentCall,
30+
AIMessage,
3531
HumanMessage,
3632
SubagentMessage,
3733
SystemMessage,
@@ -57,9 +53,9 @@ def test_map_message_from_langchain_ai_with_agent_call(self) -> None:
5753
name=f"{lc.AGENT_PREFIX}assistant", args={"q": "test"}, id="tc-2"
5854
)
5955
message = LC_AIMessage(content="done", tool_calls=[tool_call])
60-
6156
mapped = lc._map_message_from_langchain(message)
6257

58+
assert isinstance(mapped, AIMessage)
6359
assert mapped.calls == [
6460
AgentCall(
6561
name="assistant",
@@ -77,6 +73,7 @@ def test_map_message_from_langchain_ai_with_mixed_calls(self) -> None:
7773

7874
mapped = lc._map_message_from_langchain(message)
7975

76+
assert isinstance(mapped, AIMessage)
8077
assert mapped.calls == [
8178
ToolCall(name="lookup", args={"q": "test"}, id="tc-1"),
8279
AgentCall(
@@ -129,7 +126,7 @@ def test_map_message_from_langchain_subagent(self) -> None:
129126

130127
def test_map_message_from_langchain_invalid_raises(self) -> None:
131128
with pytest.raises(InvalidMessageTypeError):
132-
lc._map_message_from_langchain(object())
129+
lc._map_message_from_langchain(object()) # pyright: ignore[reportArgumentType]
133130

134131

135132
class MapMessageToLangchainTests(unittest.TestCase):
@@ -280,7 +277,7 @@ def test_map_message_to_langchain_subagent(self) -> None:
280277

281278
def test_map_message_to_langchain_invalid_raises(self) -> None:
282279
with pytest.raises(InvalidMessageTypeError):
283-
lc._map_message_to_langchain(object())
280+
lc._map_message_to_langchain(object()) # pyright: ignore[reportArgumentType]
284281

285282

286283
class CreateLangchainModelTests(unittest.TestCase):
@@ -289,7 +286,9 @@ def test_create_langchain_model_invalid_raises(self) -> None:
289286
lc._create_langchain_model(PredefinedModel(model="unknown"))
290287

291288
def test_create_langchain_model_openai(self) -> None:
292-
langchain_openai = pytest.importorskip("langchain_openai")
289+
pytest.importorskip("langchain_openai")
290+
import langchain_openai
291+
293292
model = OpenAIModel(
294293
model="gpt-test",
295294
base_url="https://example.com",
@@ -302,7 +301,3 @@ def test_create_langchain_model_openai(self) -> None:
302301
assert result.model_name == model.model
303302
assert result.openai_api_base == model.base_url
304303
assert result.temperature == model.temperature
305-
306-
307-
if __name__ == "__main__":
308-
unittest.main()

0 commit comments

Comments
 (0)