Skip to content

Commit 0d0bbc2

Browse files
authored
Support async local tools (#48)
1 parent f0d7430 commit 0d0bbc2

4 files changed

Lines changed: 61 additions & 2 deletions

File tree

splunklib/ai/registry.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,14 @@ async def _() -> list[types.Tool]:
102102

103103
@self._server.call_tool(validate_input=True)
104104
async def _(name: str, arguments: dict[str, Any]) -> types.CallToolResult:
105-
return self._call_tool(name, arguments)
105+
return await self._call_tool(name, arguments)
106106

107107
def _list_tools(self) -> list[types.Tool]:
108108
return self._tools
109109

110-
def _call_tool(self, name: str, arguments: dict[str, Any]) -> types.CallToolResult:
110+
async def _call_tool(
111+
self, name: str, arguments: dict[str, Any]
112+
) -> types.CallToolResult:
111113
func = self._tools_func.get(name)
112114
if func is None:
113115
raise ValueError(f"Tool {name} does not exist")
@@ -129,6 +131,11 @@ def _call_tool(self, name: str, arguments: dict[str, Any]) -> types.CallToolResu
129131

130132
res = func(**arguments)
131133

134+
# In case func was an async function, await the returned coroutine.
135+
# If not then we already have the result.
136+
if inspect.isawaitable(res):
137+
res = await res
138+
132139
if self._tools_wrapped_result.get(name):
133140
res = _WrappedResult(res)
134141

tests/integration/ai/test_registry.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ async def test_missing_meta_params(self):
112112
self.assertEqual(res.structuredContent, None)
113113

114114

115+
class TestAsyncToolRegistry(TestRegistryTestCase):
116+
async def test_tool_hello(self):
117+
async with self.connect("async_tool.py") as session:
118+
res = await session.call_tool(
119+
"hello",
120+
arguments={"name": "Stefan"},
121+
meta={
122+
"splunk": {
123+
"management_token": self.get_splunk_token(),
124+
"management_url": self.splunk_url,
125+
}
126+
},
127+
)
128+
self.assertEqual(res.isError, False)
129+
self.assertEqual(res.content, [])
130+
self.assertEqual(res.structuredContent, {"result": "Hello Stefan"})
131+
132+
115133
if __name__ == "__main__":
116134
import unittest
117135

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from splunklib.ai.registry import ToolRegistry
2+
3+
registry = ToolRegistry()
4+
5+
6+
@registry.tool()
7+
async def hello(name: str) -> str:
8+
return f"Hello {name}"
9+
10+
11+
registry.run()

tests/unit/ai/test_registry_unit.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,29 @@ def fancy_tool(foo: int | None, bar: Data, baz: int = -1) -> Data:
331331
},
332332
)
333333

334+
def test_async_tool(self) -> None:
335+
r = ToolRegistry()
336+
337+
@r.tool()
338+
async def str_tool() -> str:
339+
return ""
340+
341+
tool = r._tools[0]
342+
self.assertEqual(tool.name, "str_tool")
343+
self.assertEqual(
344+
tool.inputSchema,
345+
{"properties": {}, "type": "object", "additionalProperties": False},
346+
)
347+
self.assertEqual(
348+
tool.outputSchema,
349+
{
350+
"properties": {"result": {"title": "Result", "type": "string"}},
351+
"required": ["result"],
352+
"title": "_WrappedResult",
353+
"type": "object",
354+
},
355+
)
356+
334357

335358
class TestParams(unittest.TestCase):
336359
def test_description_param(self) -> None:

0 commit comments

Comments
 (0)